winglian committed
Commit: 132eb74
Parent(s): 5ed2939

DBRX Model Support (#1462)


* wip for dbrx finetuning

* add fastcore for parallel loading of sharded weights

* fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback

* update to use v2 of the converted model

* more fixes for dbrx loras

* make sure to enable fsdp activation checkpointing

* fix support for 8bit loras too for dbrx

* apply z3 leaf moe fix for DBRX with deepspeed

* don't raise value error since child module searches could fail and be ok

* revert a previous change to fix fsdp

* update mistral/mistral qlora+fsdp yamls

* fix qlora+fsdp quant storage type

* more edge cases for qlora-fsdp

* fixes for fsdp+qlora w optimizer in 8bit

* add bigstral z3 config and make sure to use full_state_dict for fsdp

deepspeed_configs/zero3_bf16_cpuoffload_all.json CHANGED
@@ -1,4 +1,6 @@
 {
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_allow_untested_optimizer": true,
   "zero_optimization": {
     "stage": 3,
     "offload_optimizer": {
deepspeed_configs/zero3_bf16_cpuoffload_params.json CHANGED
@@ -1,4 +1,6 @@
 {
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_allow_untested_optimizer": true,
   "zero_optimization": {
     "stage": 3,
     "offload_param": {
examples/dbrx/16bit-lora.yaml ADDED
@@ -0,0 +1,81 @@
+base_model: LnL-AI/dbrx-base-converted-v2
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 512
+sample_packing: false
+pad_to_sequence_len: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+adapter: lora
+lora_model_dir:
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+# w1, w2, & v1 will hang the trainer
+lora_target_modules:
+  - q_proj # attn
+  - k_proj # attn
+  - v_proj # attn
+  - out_proj # attn
+  - layer # router
+# - w1
+# - w2
+# - v1
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: false
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_activation_checkpointing: true
examples/dbrx/8bit-lora.yaml ADDED
@@ -0,0 +1,81 @@
+base_model: LnL-AI/dbrx-base-converted-v2
+trust_remote_code: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 512
+sample_packing: false
+pad_to_sequence_len: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+adapter: lora
+lora_model_dir:
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+# w1, w2, & v1 will hang the trainer
+lora_target_modules:
+  - q_proj # attn
+  - k_proj # attn
+  - v_proj # attn
+  - out_proj # attn
+  - layer # router
+# - w1
+# - w2
+# - v1
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: false
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_activation_checkpointing: true
examples/dbrx/README.md ADDED
@@ -0,0 +1,26 @@
+# DBRX MoE
+
+Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj` `out_proj` and `layer` Linear layers are trainable.
+
+We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
+where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer. However, the implementation
+is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
+results in the trainer hanging.
+
+
+### FSDP
+We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
+
+The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
+
+- 16-bit LoRA w/ FSDP
+  - ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
+  - ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
+- ✅ 8-bit LoRA w/ FSDP
+- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
+- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)
+
+
+### Deepspeed
+
+WIP
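To make the LoRA-target limitation described in the README concrete, here is a minimal inspection sketch. It is not part of the commit; it assumes the `LnL-AI/dbrx-base-converted-v2` remote code is downloadable and uses an empty-weights load, so the expected module names (`q_proj`, `out_proj`, the router `layer`, the fused `w1`/`v1`/`w2` experts) are taken from the example configs above rather than verified here.

# Hedged sketch: list which DBRX submodules are plain nn.Linear (LoRA-targetable) vs fused expert parameters.
import torch.nn as nn
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("LnL-AI/dbrx-base-converted-v2", trust_remote_code=True)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

linear_names = sorted({name.rsplit(".", 1)[-1] for name, m in model.named_modules() if isinstance(m, nn.Linear)})
print(linear_names)  # expected to include q_proj/k_proj/v_proj/out_proj and the router's `layer`

fused = [name for name, p in model.named_parameters() if name.endswith(("w1", "v1", "w2"))]
print(fused[:3])  # fused expert weights (nn.Parameter); targeting these with LoRA currently hangs the trainer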
examples/dbrx/fft-ds-zero3.yaml ADDED
@@ -0,0 +1,56 @@
+base_model: LnL-AI/dbrx-base-converted-v2
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 512
+sample_packing: false
+pad_to_sequence_len: false
+
+unfrozen_parameters:
+  - transformer.blocks.[0-7].
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+deepspeed: deepspeed_configs/zero3_bf16.json
examples/llama-2/qlora-fsdp.yml CHANGED
@@ -65,12 +65,14 @@ deepspeed:
 weight_decay: 0.0
 fsdp:
   - full_shard
+  - auto_wrap
 fsdp_config:
   fsdp_limit_all_gathers: true
   fsdp_sync_module_states: true
   fsdp_offload_params: true
   fsdp_use_orig_params: false
   fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
-  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_state_dict_type: FULL_STATE_DICT
 special_tokens:
examples/mistral/bigstral-ds-zero3.yaml ADDED
@@ -0,0 +1,63 @@
+base_model: mistral-community/Mixtral-8x22B-v0.1
+model_type: AutoModelForCausalLM
+tokenizer_type: LlamaTokenizer
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+unfrozen_parameters:
+  - ^lm_head.weight$
+  - ^model.embed_tokens.weight$
+  - model.layers.4[4-9]+.block_sparse_moe.gate
+  - model.layers.4[4-9]+.block_sparse_moe.experts
+  - model.layers.5[0-5]+.block_sparse_moe.gate
+  - model.layers.5[0-5]+.block_sparse_moe.experts
+
+model_config:
+  output_router_logits: true
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./out
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0001
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+save_total_limit: 1
+save_steps:
+debug:
+deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  eos_token: "<|im_end|>"
+tokens:
+  - "<|im_start|>"
examples/mistral/mistral-qlora-fsdp.yml ADDED
@@ -0,0 +1,82 @@
+base_model: mistralai/Mixtral-8x7B-v0.1
+model_type: AutoModelForCausalLM
+tokenizer_type: LlamaTokenizer
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+output_dir: ./qlora-out
+
+model_config:
+  output_router_logits: true
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 1024
+sample_packing: false
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: false
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: false
+  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+special_tokens:
examples/mistral/mixtral-8x22b-qlora-fsdp.yml ADDED
@@ -0,0 +1,81 @@
+base_model: mistral-community/Mixtral-8x22B-v0.1
+model_type: AutoModelForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+output_dir: ./qlora-out
+
+model_config:
+  output_router_logits: true
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 1024
+sample_packing: false
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+special_tokens:
examples/mistral/mixtral-qlora-fsdp.yml CHANGED
@@ -39,7 +39,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
 num_epochs: 1
-optimizer: paged_adamw_8bit
+optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.0002
 
@@ -47,7 +47,7 @@ train_on_inputs: false
 group_by_length: false
 bf16: auto
 fp16:
-tf32: false
+tf32: true
 
 gradient_checkpointing: true
 early_stopping_patience:
@@ -69,6 +69,17 @@ debug:
 weight_decay: 0.0
 fsdp:
   - full_shard
+  - auto_wrap
 fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
   fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_forward_prefetch: false
+  fsdp_backward_prefetch: BACKWARD_PRE
 special_tokens:
requirements.txt CHANGED
@@ -41,3 +41,4 @@ gcsfs
 
 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
 zstandard==0.22.0
+fastcore
src/axolotl/core/trainer_builder.py CHANGED
@@ -918,10 +918,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ):
             callbacks.append(SaveBetterTransformerModelCallback())
 
-        if self.cfg.use_wandb:
-            callbacks.append(
-                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
-            )
         if self.cfg.use_mlflow and is_mlflow_available():
             from axolotl.utils.callbacks.mlflow_ import (
                 SaveAxolotlConfigtoMlflowCallback,
src/axolotl/train.py CHANGED
@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Union
 
 import torch
 import transformers.modelcard
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import Dataset
 from peft import PeftModel
@@ -81,6 +82,8 @@ def train(
     if cfg.adapter:
         msg += " and peft_config..."
     LOG.debug(msg)
+    # we wait unitl the last possible moment to setup Accelerator
+    Accelerator()
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model.generation_config.do_sample = True
 
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -259,6 +259,7 @@ class ModelInputConfig(BaseModel):
 
     base_model: str
     base_model_config: Optional[str] = None
+    cls_model_config: Optional[str] = None
     tokenizer_config: Optional[str] = None
     tokenizer_use_fast: Optional[bool] = None
     tokenizer_legacy: Optional[bool] = None
@@ -971,9 +972,16 @@ class AxolotlInputConfig(
 
     @model_validator(mode="before")
     @classmethod
-    def check_fsdp_w_8bit_optimizer(cls, data):
-        if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
-            raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
+    def check_fsdp_offload_w_8bit_optimizer(cls, data):
+        if (
+            data.get("fsdp")
+            and "8bit" in data.get("optimizer", "")
+            and data.get("fsdp_config")
+            and data["fsdp_config"].get("fsdp_offload_params")
+        ):
+            raise ValueError(
+                f"FSDP Offload not compatible with {data.get('optimizer')}"
+            )
        return data
 
     @model_validator(mode="before")
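As a reading aid, here is a standalone mirror of the relaxed validator above (an illustration, not additional commit code): FSDP with an 8-bit optimizer is now accepted as long as `fsdp_offload_params` is not enabled, whereas the old check rejected FSDP with any "bnb" optimizer outright.

# Standalone mirror of check_fsdp_offload_w_8bit_optimizer; `data` is the raw config dict.
def check_fsdp_offload_w_8bit_optimizer(data: dict) -> dict:
    if (
        data.get("fsdp")
        and "8bit" in data.get("optimizer", "")
        and data.get("fsdp_config")
        and data["fsdp_config"].get("fsdp_offload_params")
    ):
        raise ValueError(f"FSDP Offload not compatible with {data.get('optimizer')}")
    return data


# Accepted now: FSDP + paged_adamw_8bit without CPU offload.
check_fsdp_offload_w_8bit_optimizer(
    {"fsdp": ["full_shard"], "optimizer": "paged_adamw_8bit", "fsdp_config": {"fsdp_offload_params": False}}
)

# Still rejected: the 8-bit optimizer combined with fsdp_offload_params.
try:
    check_fsdp_offload_w_8bit_optimizer(
        {"fsdp": ["full_shard"], "optimizer": "paged_adamw_8bit", "fsdp_config": {"fsdp_offload_params": True}}
    )
except ValueError as err:
    print(err)  # FSDP Offload not compatible with paged_adamw_8bit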
src/axolotl/utils/distributed.py CHANGED
@@ -4,27 +4,25 @@ utility helpers for distributed checks
 import os
 import pickle  # nosec
 from contextlib import contextmanager
+from datetime import timedelta
 
 import torch
 import torch.distributed as dist
-from accelerate import Accelerator
+from accelerate import PartialState
 
-accelerate = None  # pylint: disable=invalid-name
-
-
-def load_accelerate():
-    global accelerate  # pylint: disable=global-statement
-    accelerate = Accelerator()
+distributed_state = None  # pylint: disable=invalid-name
 
 
 def is_distributed():
     """
     Check if distributed training is initialized.
     """
-    global accelerate  # pylint: disable=global-statement
-    if not accelerate:
-        accelerate = Accelerator()
-    return dist.is_available() and dist.is_initialized()
+    global distributed_state  # pylint: disable=global-statement
+    if not distributed_state:
+        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
+        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+
+    return distributed_state.use_distributed and distributed_state.initialized
 
 
 def barrier():
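One practical consequence of the switch to `PartialState` above: the collective timeout can be raised via an environment variable before the first distributed check, which helps when rank 0 spends a long time loading and quantizing shards. A minimal sketch, assuming axolotl is installed; the env var name comes from the diff, the value shown is arbitrary.

# Sketch: raise the NCCL timeout before anything calls is_distributed().
import os

os.environ["AXOLOTL_NCCL_TIMEOUT"] = "7200"  # seconds; the code above defaults to 1800

from axolotl.utils.distributed import is_distributed  # noqa: E402

print(is_distributed())  # False in a plain single-process run; True once torch.distributed is initialized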
src/axolotl/utils/model_shard_quant.py ADDED
@@ -0,0 +1,259 @@
+"""
+module to handle loading model on cpu/meta device for FSDP
+"""
+import os
+import time
+from typing import List, Optional, Type, Union
+
+import safetensors
+import torch
+from accelerate import init_empty_weights
+from bitsandbytes.nn import Linear4bit, Params4bit
+from fastcore.parallel import parallel
+from torch import Tensor, nn
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
+
+
+def _replace_linear(
+    model: nn.Module,
+    linear_replacement: Type[nn.Module],
+    quant_config: Union[dict, None] = None,
+    skip_modules=None,
+    **kwargs,
+):
+    """
+    Replace linear modules with a new Linear module.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
+            List of modules names not to convert. Defaults to `lm_head`.
+    """
+    if skip_modules is None:
+        skip_modules = ["lm_head"]
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            _replace_linear(
+                module, linear_replacement, quant_config, skip_modules, **kwargs
+            )
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            if issubclass(linear_replacement, Linear4bit):
+                model._modules[  # pylint: disable=protected-access
+                    name
+                ] = linear_replacement(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **kwargs,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported linear replacement: {type(linear_replacement)}"
+                )
+    return model
+
+
+def load_and_quantize(
+    module: nn.Module,
+    name: str,
+    value: Tensor,
+    device: torch.device = None,
+    dtype: torch.dtype = None,
+    skip_names: Optional[List[str]] = None,
+    to_cpu: bool = False,
+    to_meta: bool = False,
+    verbose: bool = False,
+    quant_method: str = "bnb",
+):
+    """
+    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
+
+    Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
+    """
+
+    if not skip_names:
+        skip_names = []
+
+    def place_on_device(value):
+        if to_meta:
+            device = "meta"
+        elif to_cpu:
+            device = "cpu"
+        return value.to(device=device, dtype=dtype)
+
+    if any(skip_name in name for skip_name in skip_names):
+        if verbose:
+            print(f"Skipping {name} because it is in skip_names")
+        return
+
+    module_key, _, value_key = name.rpartition(".")
+    try:
+        submodule = module.get_submodule(module_key)
+    except AttributeError as exc:
+        print(f"Module {module_key} not found:\n{exc}")
+        return
+
+    try:
+        if quant_method == "bnb":
+            param = submodule.get_parameter(value_key)
+            if isinstance(param, Params4bit):
+                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
+                # shape as the quantized Params4bit with an initialized quant_state. However,
+                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
+                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
+                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
+                value = type(param)(
+                    value.to(device=device, dtype=dtype).data, **param.__dict__
+                ).cuda(device)
+                if to_meta:
+                    value = type(param)(value.data.to("meta"), **value.__dict__)
+                elif to_cpu:
+                    value = type(param)(value.data.to("cpu"), **value.__dict__)
+            else:
+                value = type(param)(place_on_device(value).data)
+
+    except AttributeError:
+        # it's a buffer
+        value = place_on_device(value)
+
+    setattr(submodule, value_key, value)
+
+
+def n_loading_workers(quant_method: str, param_count: float):
+    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
+    left = int(os.cpu_count() / torch.cuda.device_count())
+    model_params_b = 70
+    right = int(
+        (4 if quant_method == "hqq" else 8)
+        * (devprops.total_memory / 1e9 / 40)
+        * (model_params_b / (param_count / 1e9))
+    )
+    return min(left, right)
+
+
+def load_sharded_model(
+    model_name,
+    model_config,
+    cfg,
+    torch_dtype=torch.bfloat16,
+    low_memory=True,
+):
+    if (low_memory and cfg.local_rank == 0) or not low_memory:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            use_cache=False,
+            torch_dtype=torch.float32,
+            _attn_implementation=model_config._attn_implementation,  # pylint: disable=protected-access
+            trust_remote_code=cfg.trust_remote_code,
+        )
+        dtype = torch_dtype if not cfg.float32 else None
+        model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
+    else:
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(
+                model_config,
+                torch_dtype=torch_dtype,
+                trust_remote_code=cfg.trust_remote_code,
+            )
+    return model
+
+
+def load_sharded_model_quant(
+    model_name,
+    model_config,
+    cfg,
+    compute_dtype=torch.bfloat16,
+    quant_storage=torch.float32,
+    low_memory=True,
+    verbose=False,
+    loading_workers=2,
+):
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(
+            model_config,
+            trust_remote_code=cfg.trust_remote_code,
+        )
+        if hasattr(model, "transformer"):
+            model.transformer = _replace_linear(
+                model.transformer,
+                Linear4bit,
+                compute_dtype=compute_dtype,
+                quant_type="nf4",
+                quant_storage=quant_storage,
+            )
+        else:
+            # this is the more common case with HF transformers
+            model.model = _replace_linear(
+                model.model,
+                Linear4bit,
+                compute_dtype=compute_dtype,
+                quant_type="nf4",
+                quant_storage=quant_storage,
+            )
+    model.is_loaded_in_4bit = True
+
+    # Grab the safetensors files that hold the weights
+    try:
+        idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
+        files, _ = hub.get_checkpoint_shard_files(model_name, idx)
+    except OSError:
+        try:
+            # This means the model doesn't have a model.safetensors.index.json because it is not sharded
+            files = []
+            files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))
+        except OSError as exc:
+            # This means the model probably doesn't have a safetensors file
+            raise exc
+
+    # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
+    # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
+    def load_and_quantize_parallel(name_param, model, **kwargs):
+        name, param = name_param
+        load_and_quantize(model, name, param, **kwargs)
+
+    quant_method = "bnb"
+    param_count = sum((p.numel() for n, p in model.named_parameters()))
+
+    n_workers = (
+        n_loading_workers(quant_method, param_count)
+        if loading_workers == -1
+        else loading_workers
+    )
+    if cfg.local_rank == 0 and verbose:
+        print(f"Using n_workers: {n_workers} for loading")
+
+    start = time.time()
+    for filename in tqdm(
+        files,
+        desc="Loading & Quantizing Model Shards",
+        disable=cfg.local_rank != 0,
+        position=0,
+    ):
+        weights = safetensors.torch.load_file(filename)
+        parallel(
+            load_and_quantize_parallel,
+            iter(weights.items()),
+            n_workers=n_workers,
+            threadpool=True,
+            model=model,
+            dtype=quant_storage,
+            device=cfg.local_rank,
+            skip_names=[],
+            to_cpu=(low_memory and cfg.local_rank == 0),
+            to_meta=(low_memory and cfg.local_rank != 0),
+            verbose=verbose,
+            quant_method=quant_method,
+        )
+
+    if cfg.local_rank == 0 and verbose:
+        print(f"Loaded model weights in {time.time()-start:.3f} seconds")
+    # cleanup any extra memory usage from parallel loading
+    torch.cuda.empty_cache()
+
+    return model
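For orientation, a hedged usage sketch of the new loader, mirroring how `load_model` wires it up for DBRX QLoRA+FSDP later in this commit. The `cfg` object here is a minimal stand-in carrying only the fields this module reads; it should be run under a distributed launcher in practice.

# Hedged sketch: quantize-and-shard a DBRX checkpoint on load, keeping weights on CPU on rank 0
# and on the meta device on other ranks (low_memory defaults to True).
import torch
from transformers import AutoConfig

from axolotl.utils.dict import DictDefault
from axolotl.utils.model_shard_quant import load_sharded_model_quant

base_model = "LnL-AI/dbrx-base-converted-v2"
model_config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
cfg = DictDefault(
    local_rank=0,  # rank 0 keeps quantized weights on CPU; other ranks hold meta tensors
    trust_remote_code=True,
)

model = load_sharded_model_quant(
    base_model,
    model_config,
    cfg,
    compute_dtype=torch.bfloat16,
    quant_storage=torch.bfloat16,  # mirrors the qlora+fsdp quant-storage handling in this PR
)
print(type(model).__name__, model.is_loaded_in_4bit)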
src/axolotl/utils/models.py CHANGED
@@ -45,10 +45,35 @@ from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import zero_only
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
+from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
 
 LOG = logging.getLogger("axolotl")
 
 
+# copied from accelerator.FullyShardedDataParallelPlugin
+def get_module_class_from_name(module, name):
+    """
+    Gets a class from a module by its name.
+
+    Args:
+        module (`torch.nn.Module`): The module to get the class from.
+        name (`str`): The name of the class.
+    """
+    modules_children = list(module.children())
+    if module.__class__.__name__ == name:
+        return module.__class__
+
+    if len(modules_children) == 0:
+        return None
+
+    for child_module in modules_children:
+        module_class = get_module_class_from_name(child_module, name)
+        if module_class is not None:
+            return module_class
+
+    return None
+
+
 def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
     quant_config_exists = (
         hasattr(model_config, "quantization_config")
@@ -459,7 +484,7 @@ def load_model(
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_quant_storage": torch.bfloat16,
        }
-        if not cfg.deepspeed:
+        if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
            # for some reason, this causes the loss to be off by an order of magnitude
            # but deepspeed needs this still in bfloat16
            bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -470,6 +495,13 @@ def load_model(
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            **bnb_config,
        )
+    elif cfg.adapter == "lora" and cfg.load_in_8bit:
+        bnb_config = {
+            "load_in_8bit": True,
+        }
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            **bnb_config,
+        )
 
     if cfg.load_in_8bit and cfg.adapter is not None:
         model_kwargs["load_in_8bit"] = True
@@ -517,7 +549,31 @@ def load_model(
     qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
 
     try:
+        skip_move_to_device = False
         if (
+            cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        ) and not qlora_fsdp:
+            model = load_sharded_model(
+                base_model,
+                model_config,
+                cfg,
+                torch_dtype=cfg.torch_dtype,
+            )
+            skip_move_to_device = True
+        elif (
+            qlora_fsdp
+            and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+            and cfg.model_config_type == "dbrx"
+        ):
+            quant_storage = cfg.torch_dtype
+            model = load_sharded_model_quant(
+                base_model,
+                model_config,
+                cfg,
+                quant_storage=quant_storage,
+            )
+            skip_move_to_device = True
+        elif (
             model_config.model_type == "llama"
             and not cfg.trust_remote_code
             and not cfg.gptq
@@ -597,6 +653,11 @@ def load_model(
                **model_kwargs,
            )
        else:
+            if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+                skip_move_to_device = True
+                if "device_map" in model_kwargs:
+                    del model_kwargs["device_map"]
+
            model = AutoModelForCausalLM.from_pretrained(
                base_model,
                config=model_config,
@@ -670,13 +731,17 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
 
-    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+    if is_deepspeed_zero3_enabled():
        from deepspeed.utils import (  # pylint: disable=no-name-in-module
            set_z3_leaf_modules,
        )
-        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
-        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+        if cfg.model_config_type == "mixtral":
+            moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
+            set_z3_leaf_modules(model, [moe_block])
+        elif cfg.model_config_type == "dbrx":
+            moe_block = get_module_class_from_name(model, "DbrxFFN")
+            set_z3_leaf_modules(model, [moe_block])
 
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
        # Qwen doesn't play nicely with LoRA if this is enabled
@@ -686,7 +751,8 @@ def load_model(
     if cfg.adapter == "lora" and loftq_bits:
        skip_prepare_model_for_kbit_training = True
 
-    if qlora_fsdp:
+    if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
+        # make sure everything is in the same dtype
        skip_prepare_model_for_kbit_training = True
 
     if cfg.adapter in ["lora", "qlora"]:
@@ -727,7 +793,7 @@ def load_model(
        cfg.ddp
        and not load_in_8bit
        and not (cfg.rl and cfg.load_in_4bit)
-        and not qlora_fsdp
+        and not skip_move_to_device
     ):
        # TODO revaldate this conditional
        model.to(f"cuda:{cfg.local_rank}")
@@ -883,7 +949,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
     rank = int(os.environ.get("LOCAL_RANK", 0))
 
-    if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
+    if (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
        setup_quantized_meta_for_peft(model)
 
     if cfg.lora_model_dir:
@@ -908,7 +979,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
            LOG.warning(
                "Exception caught during model.print_trainable_parameters(): %s", exc
            )
-    elif cfg.fsdp and cfg.adapter == "qlora":
+    elif (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
        setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
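The DeepSpeed ZeRO-3 change in the hunk above can be summarized as a small helper (an illustration, not additional commit code): the MoE block class is looked up on the instantiated model via `get_module_class_from_name` rather than imported, and per the commit notes a failed child-module search is tolerated instead of raised.

# Illustration of the z3 leaf-module handling, assuming DeepSpeed ZeRO-3 is enabled and `model` is already loaded.
from deepspeed.utils import set_z3_leaf_modules  # pylint: disable=no-name-in-module

from axolotl.utils.models import get_module_class_from_name


def mark_moe_leaf_modules(model, model_config_type: str):
    class_name = {"mixtral": "MixtralSparseMoeBlock", "dbrx": "DbrxFFN"}.get(model_config_type)
    if class_name is None:
        return model
    moe_block = get_module_class_from_name(model, class_name)
    if moe_block is not None:  # a failed child-module search is acceptable; don't raise
        set_z3_leaf_modules(model, [moe_block])
    return model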
src/axolotl/utils/trainer.py CHANGED
@@ -306,6 +306,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
 
 def setup_fsdp_envs(cfg):
     os.environ["ACCELERATE_USE_FSDP"] = "true"
+    if cfg.fsdp_config.fsdp_activation_checkpointing:
+        os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
     if cfg.fsdp_config.fsdp_offload_params:
         os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
     if cfg.fsdp_config.fsdp_sync_module_states: