DBRX Model Support (#1462)
* wip for dbrx finetuning
* add fastcore for parallel loading of sharded weights
* fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback
* update to use v2 of the converted model
* more fixes for dbrx loras
* make sure to enable fsdp activation checkpointing
* fix support for 8bit loras too for dbrx
* apply z3 leaf moe fix for DBRX with deepspeed
* don't raise value error since child module searches could fail and be ok
* revert a previous change to fix fsdp
* update mistral/mistral qlora+fsdp yamls
* fix qlora+fsdp quant storage type
* more edge cases for qlora-fsdp
* fixes for fsdp+qlora w optimizer in 8bit
* add bigstral z3 config and make sure to use full_state_dict for fsdp
- deepspeed_configs/zero3_bf16_cpuoffload_all.json +2 -0
- deepspeed_configs/zero3_bf16_cpuoffload_params.json +2 -0
- examples/dbrx/16bit-lora.yaml +81 -0
- examples/dbrx/8bit-lora.yaml +81 -0
- examples/dbrx/README.md +26 -0
- examples/dbrx/fft-ds-zero3.yaml +56 -0
- examples/llama-2/qlora-fsdp.yml +3 -1
- examples/mistral/bigstral-ds-zero3.yaml +63 -0
- examples/mistral/mistral-qlora-fsdp.yml +82 -0
- examples/mistral/mixtral-8x22b-qlora-fsdp.yml +81 -0
- examples/mistral/mixtral-qlora-fsdp.yml +13 -2
- requirements.txt +1 -0
- src/axolotl/core/trainer_builder.py +0 -4
- src/axolotl/train.py +3 -0
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +11 -3
- src/axolotl/utils/distributed.py +9 -11
- src/axolotl/utils/model_shard_quant.py +259 -0
- src/axolotl/utils/models.py +84 -8
- src/axolotl/utils/trainer.py +2 -0
deepspeed_configs/zero3_bf16_cpuoffload_all.json
CHANGED
```diff
@@ -1,4 +1,6 @@
 {
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_allow_untested_optimizer": true,
   "zero_optimization": {
     "stage": 3,
     "offload_optimizer": {
```
deepspeed_configs/zero3_bf16_cpuoffload_params.json
CHANGED
```diff
@@ -1,4 +1,6 @@
 {
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_allow_untested_optimizer": true,
   "zero_optimization": {
     "stage": 3,
     "offload_param": {
```
examples/dbrx/16bit-lora.yaml
ADDED
```yaml
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
  - q_proj # attn
  - k_proj # attn
  - v_proj # attn
  - out_proj # attn
  - layer # router
#  - w1
#  - w2
#  - v1

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: false  # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_activation_checkpointing: true
```
examples/dbrx/8bit-lora.yaml
ADDED
```yaml
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
  - q_proj # attn
  - k_proj # attn
  - v_proj # attn
  - out_proj # attn
  - layer # router
#  - w1
#  - w2
#  - v1

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: false  # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_activation_checkpointing: true
```
examples/dbrx/README.md
ADDED
```markdown
# DBRX MoE

Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj`, `out_proj` and `layer` Linear layers are trainable.

We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
where the experts are fused as an `nn.Parameter` rather than an `nn.Linear` layer. However, the implementation
is still a bit buggy, and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
results in the trainer hanging.


### FSDP
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.

The high memory usage seen w/ FSDP is due to FSDP not supporting 8-bit optimizers.

- 16-bit LoRA w/ FSDP
  - ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
  - ❌ w/ CPU Offload - the `paged_adamw_8bit` optimizer errors when its state is offloaded to CPU
- ✅ 8-bit LoRA w/ FSDP
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)


### Deepspeed

WIP
```
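To make the fused-expert point in the README above concrete, here is a small, purely illustrative sketch (not DBRX's actual module code; names and shapes are made up). PEFT injects LoRA adapters into `nn.Linear` submodules, but the converted checkpoint stores every expert's weights inside one flat `nn.Parameter`, so there is no linear module for an adapter to attach to on `w1`, `w2`, or `v1`.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FusedExpertsSketch(nn.Module):
    """Toy stand-in for fused MoE expert weights (not the real DBRX module)."""

    def __init__(self, num_experts: int, d_model: int, ffn_dim: int):
        super().__init__()
        # every expert's w1 lives in a single flat nn.Parameter, so there is no
        # per-expert nn.Linear submodule for PEFT/LoRA to wrap
        self.w1 = nn.Parameter(torch.randn(num_experts * ffn_dim, d_model) * 0.02)
        self.ffn_dim = ffn_dim

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        # slice the requested expert's weight out of the fused parameter
        w1 = self.w1.view(-1, self.ffn_dim, x.shape[-1])[expert_idx]
        return F.linear(x, w1)


moe = FusedExpertsSketch(num_experts=4, d_model=8, ffn_dim=16)
out = moe(torch.randn(2, 8), expert_idx=1)  # shape (2, 16)
```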
examples/dbrx/fft-ds-zero3.yaml
ADDED
```yaml
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

unfrozen_parameters:
  - transformer.blocks.[0-7].

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
deepspeed: deepspeed_configs/zero3_bf16.json
```
examples/llama-2/qlora-fsdp.yml
CHANGED
```diff
@@ -65,12 +65,14 @@ deepspeed:
 weight_decay: 0.0
 fsdp:
   - full_shard
+  - auto_wrap
 fsdp_config:
   fsdp_limit_all_gathers: true
   fsdp_sync_module_states: true
   fsdp_offload_params: true
   fsdp_use_orig_params: false
   fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
-  fsdp_state_dict_type:
+  fsdp_state_dict_type: FULL_STATE_DICT
 special_tokens:
```
examples/mistral/bigstral-ds-zero3.yaml
ADDED
```yaml
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

unfrozen_parameters:
  - ^lm_head.weight$
  - ^model.embed_tokens.weight$
  - model.layers.4[4-9]+.block_sparse_moe.gate
  - model.layers.4[4-9]+.block_sparse_moe.experts
  - model.layers.5[0-5]+.block_sparse_moe.gate
  - model.layers.5[0-5]+.block_sparse_moe.experts

model_config:
  output_router_logits: true

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  eos_token: "<|im_end|>"
tokens:
  - "<|im_start|>"
```
examples/mistral/mistral-qlora-fsdp.yml
ADDED
```yaml
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out

model_config:
  output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: false
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
```
examples/mistral/mixtral-8x22b-qlora-fsdp.yml
ADDED
```yaml
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out

model_config:
  output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
```
examples/mistral/mixtral-qlora-fsdp.yml
CHANGED
```diff
@@ -39,7 +39,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
 num_epochs: 1
-optimizer:
+optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.0002
 
@@ -47,7 +47,7 @@ train_on_inputs: false
 group_by_length: false
 bf16: auto
 fp16:
-tf32:
+tf32: true
 
 gradient_checkpointing: true
 early_stopping_patience:
@@ -69,6 +69,17 @@ debug:
 weight_decay: 0.0
 fsdp:
   - full_shard
+  - auto_wrap
 fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
   fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_forward_prefetch: false
+  fsdp_backward_prefetch: BACKWARD_PRE
 special_tokens:
```
requirements.txt
CHANGED
```diff
@@ -41,3 +41,4 @@ gcsfs
 
 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
 zstandard==0.22.0
+fastcore
```
src/axolotl/core/trainer_builder.py
CHANGED
```diff
@@ -918,10 +918,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ):
             callbacks.append(SaveBetterTransformerModelCallback())
 
-        if self.cfg.use_wandb:
-            callbacks.append(
-                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
-            )
         if self.cfg.use_mlflow and is_mlflow_available():
             from axolotl.utils.callbacks.mlflow_ import (
                 SaveAxolotlConfigtoMlflowCallback,
```
src/axolotl/train.py
CHANGED
```diff
@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Union
 
 import torch
 import transformers.modelcard
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import Dataset
 from peft import PeftModel
@@ -81,6 +82,8 @@ def train(
     if cfg.adapter:
         msg += " and peft_config..."
     LOG.debug(msg)
+    # we wait until the last possible moment to setup Accelerator
+    Accelerator()
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model.generation_config.do_sample = True
 
```
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
```diff
@@ -259,6 +259,7 @@ class ModelInputConfig(BaseModel):
 
     base_model: str
     base_model_config: Optional[str] = None
+    cls_model_config: Optional[str] = None
     tokenizer_config: Optional[str] = None
     tokenizer_use_fast: Optional[bool] = None
     tokenizer_legacy: Optional[bool] = None
@@ -971,9 +972,16 @@ class AxolotlInputConfig(
 
     @model_validator(mode="before")
     @classmethod
-    def ...
-        if ...
-            ...
+    def check_fsdp_offload_w_8bit_optimizer(cls, data):
+        if (
+            data.get("fsdp")
+            and "8bit" in data.get("optimizer", "")
+            and data.get("fsdp_config")
+            and data["fsdp_config"].get("fsdp_offload_params")
+        ):
+            raise ValueError(
+                f"FSDP Offload not compatible with {data.get('optimizer')}"
+            )
         return data
 
     @model_validator(mode="before")
```
(`...` marks removed lines that were truncated in the original diff view.)
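For reference, here is a standalone sketch of what the new validator enforces, using a plain dict in place of the raw config that the pydantic `@model_validator(mode="before")` hook receives; the example values below are hypothetical.

```python
# mirrors the logic of check_fsdp_offload_w_8bit_optimizer added in this PR
def check_fsdp_offload_w_8bit_optimizer(data: dict) -> dict:
    if (
        data.get("fsdp")
        and "8bit" in data.get("optimizer", "")
        and data.get("fsdp_config")
        and data["fsdp_config"].get("fsdp_offload_params")
    ):
        raise ValueError(f"FSDP Offload not compatible with {data.get('optimizer')}")
    return data


# e.g. pairing fsdp_offload_params: true with an 8-bit optimizer now fails fast
try:
    check_fsdp_offload_w_8bit_optimizer(
        {
            "fsdp": ["full_shard", "auto_wrap"],
            "optimizer": "paged_adamw_8bit",
            "fsdp_config": {"fsdp_offload_params": True},
        }
    )
except ValueError as err:
    print(err)  # FSDP Offload not compatible with paged_adamw_8bit
```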
src/axolotl/utils/distributed.py
CHANGED
```diff
@@ -4,27 +4,25 @@ utility helpers for distributed checks
 import os
 import pickle  # nosec
 from contextlib import contextmanager
+from datetime import timedelta
 
 import torch
 import torch.distributed as dist
-from accelerate import Accelerator
+from accelerate import PartialState
 
-...
-def load_accelerate():
-    global accelerate  # pylint: disable=global-statement
-    accelerate = Accelerator()
+distributed_state = None  # pylint: disable=invalid-name
 
 
 def is_distributed():
     """
     Check if distributed training is initialized.
     """
-    global ...
-    if not ...
-        ...
+    global distributed_state  # pylint: disable=global-statement
+    if not distributed_state:
+        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
+        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+
+    return distributed_state.use_distributed and distributed_state.initialized
 
 
 def barrier():
```
(`...` marks removed lines that were truncated in the original diff view.)
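A usage sketch of the reworked helper, assuming the process was launched via `torchrun`/`accelerate launch` so the usual distributed environment variables are already set; `AXOLOTL_NCCL_TIMEOUT` is the new knob read lazily by `is_distributed()` above.

```python
import os

from axolotl.utils.distributed import barrier, is_distributed

# read lazily on the first is_distributed() call, so it just needs to be set
# before that call; here we raise the NCCL timeout from the 1800s default to 1h
os.environ.setdefault("AXOLOTL_NCCL_TIMEOUT", "3600")

if is_distributed():  # first call creates the shared PartialState
    barrier()
```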
src/axolotl/utils/model_shard_quant.py
ADDED
```python
"""
module to handle loading model on cpu/meta device for FSDP
"""
import os
import time
from typing import List, Optional, Type, Union

import safetensors
import torch
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub


def _replace_linear(
    model: nn.Module,
    linear_replacement: Type[nn.Module],
    quant_config: Union[dict, None] = None,
    skip_modules=None,
    **kwargs,
):
    """
    Replace linear modules with a new Linear module.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of modules names not to convert. Defaults to `lm_head`.
    """
    if skip_modules is None:
        skip_modules = ["lm_head"]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            _replace_linear(
                module, linear_replacement, quant_config, skip_modules, **kwargs
            )

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            if issubclass(linear_replacement, Linear4bit):
                model._modules[  # pylint: disable=protected-access
                    name
                ] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs,
                )
            else:
                raise ValueError(
                    f"Unsupported linear replacement: {type(linear_replacement)}"
                )
    return model


def load_and_quantize(
    module: nn.Module,
    name: str,
    value: Tensor,
    device: torch.device = None,
    dtype: torch.dtype = None,
    skip_names: Optional[List[str]] = None,
    to_cpu: bool = False,
    to_meta: bool = False,
    verbose: bool = False,
    quant_method: str = "bnb",
):
    """
    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.

    Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
    """

    if not skip_names:
        skip_names = []

    def place_on_device(value):
        if to_meta:
            device = "meta"
        elif to_cpu:
            device = "cpu"
        return value.to(device=device, dtype=dtype)

    if any(skip_name in name for skip_name in skip_names):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition(".")
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as exc:
        print(f"Module {module_key} not found:\n{exc}")
        return

    try:
        if quant_method == "bnb":
            param = submodule.get_parameter(value_key)
            if isinstance(param, Params4bit):
                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
                # shape as the quantized Params4bit with an initialized quant_state. However,
                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
                value = type(param)(
                    value.to(device=device, dtype=dtype).data, **param.__dict__
                ).cuda(device)
                if to_meta:
                    value = type(param)(value.data.to("meta"), **value.__dict__)
                elif to_cpu:
                    value = type(param)(value.data.to("cpu"), **value.__dict__)
            else:
                value = type(param)(place_on_device(value).data)

    except AttributeError:
        # it's a buffer
        value = place_on_device(value)

    setattr(submodule, value_key, value)


def n_loading_workers(quant_method: str, param_count: float):
    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
    left = int(os.cpu_count() / torch.cuda.device_count())
    model_params_b = 70
    right = int(
        (4 if quant_method == "hqq" else 8)
        * (devprops.total_memory / 1e9 / 40)
        * (model_params_b / (param_count / 1e9))
    )
    return min(left, right)


def load_sharded_model(
    model_name,
    model_config,
    cfg,
    torch_dtype=torch.bfloat16,
    low_memory=True,
):
    if (low_memory and cfg.local_rank == 0) or not low_memory:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            use_cache=False,
            torch_dtype=torch.float32,
            _attn_implementation=model_config._attn_implementation,  # pylint: disable=protected-access
            trust_remote_code=cfg.trust_remote_code,
        )
        dtype = torch_dtype if not cfg.float32 else None
        model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
    else:
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                model_config,
                torch_dtype=torch_dtype,
                trust_remote_code=cfg.trust_remote_code,
            )
    return model


def load_sharded_model_quant(
    model_name,
    model_config,
    cfg,
    compute_dtype=torch.bfloat16,
    quant_storage=torch.float32,
    low_memory=True,
    verbose=False,
    loading_workers=2,
):
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            model_config,
            trust_remote_code=cfg.trust_remote_code,
        )
    if hasattr(model, "transformer"):
        model.transformer = _replace_linear(
            model.transformer,
            Linear4bit,
            compute_dtype=compute_dtype,
            quant_type="nf4",
            quant_storage=quant_storage,
        )
    else:
        # this is the more common case with HF transformers
        model.model = _replace_linear(
            model.model,
            Linear4bit,
            compute_dtype=compute_dtype,
            quant_type="nf4",
            quant_storage=quant_storage,
        )
    model.is_loaded_in_4bit = True

    # Grab the safetensors files that hold the weights
    try:
        idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
        files, _ = hub.get_checkpoint_shard_files(model_name, idx)
    except OSError:
        try:
            # This means the model doesn't have a model.safetensors.index.json because it is not sharded
            files = []
            files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))
        except OSError as exc:
            # This means the model probably doesn't have a safetensors file
            raise exc

    # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
    # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
    def load_and_quantize_parallel(name_param, model, **kwargs):
        name, param = name_param
        load_and_quantize(model, name, param, **kwargs)

    quant_method = "bnb"
    param_count = sum((p.numel() for n, p in model.named_parameters()))

    n_workers = (
        n_loading_workers(quant_method, param_count)
        if loading_workers == -1
        else loading_workers
    )
    if cfg.local_rank == 0 and verbose:
        print(f"Using n_workers: {n_workers} for loading")

    start = time.time()
    for filename in tqdm(
        files,
        desc="Loading & Quantizing Model Shards",
        disable=cfg.local_rank != 0,
        position=0,
    ):
        weights = safetensors.torch.load_file(filename)
        parallel(
            load_and_quantize_parallel,
            iter(weights.items()),
            n_workers=n_workers,
            threadpool=True,
            model=model,
            dtype=quant_storage,
            device=cfg.local_rank,
            skip_names=[],
            to_cpu=(low_memory and cfg.local_rank == 0),
            to_meta=(low_memory and cfg.local_rank != 0),
            verbose=verbose,
            quant_method=quant_method,
        )

    if cfg.local_rank == 0 and verbose:
        print(f"Loaded model weights in {time.time()-start:.3f} seconds")
    # cleanup any extra memory usage from parallel loading
    torch.cuda.empty_cache()

    return model
```
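A hypothetical usage sketch of the new quantized shard loader outside of `load_model()`; the `SimpleNamespace` stands in for axolotl's config object and only carries the two fields the function actually reads (`local_rank`, `trust_remote_code`).

```python
from types import SimpleNamespace

import torch
from transformers import AutoConfig

from axolotl.utils.model_shard_quant import load_sharded_model_quant

model_name = "LnL-AI/dbrx-base-converted-v2"
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
cfg = SimpleNamespace(local_rank=0, trust_remote_code=True)

# rank 0 keeps the quantized weights on CPU (low_memory default); other ranks
# would end up holding meta tensors until FSDP syncs module states
model = load_sharded_model_quant(
    model_name,
    model_config,
    cfg,
    compute_dtype=torch.bfloat16,
    quant_storage=torch.bfloat16,  # matches cfg.torch_dtype in the DBRX qlora+fsdp path
)
```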
src/axolotl/utils/models.py
CHANGED
```diff
@@ -45,10 +45,35 @@ from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import zero_only
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
+from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
 
 LOG = logging.getLogger("axolotl")
 
 
+# copied from accelerator.FullyShardedDataParallelPlugin
+def get_module_class_from_name(module, name):
+    """
+    Gets a class from a module by its name.
+
+    Args:
+        module (`torch.nn.Module`): The module to get the class from.
+        name (`str`): The name of the class.
+    """
+    modules_children = list(module.children())
+    if module.__class__.__name__ == name:
+        return module.__class__
+
+    if len(modules_children) == 0:
+        return None
+
+    for child_module in modules_children:
+        module_class = get_module_class_from_name(child_module, name)
+        if module_class is not None:
+            return module_class
+
+    return None
+
+
 def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
     quant_config_exists = (
         hasattr(model_config, "quantization_config")
@@ -459,7 +484,7 @@ def load_model(
             "bnb_4bit_quant_type": "nf4",
             "bnb_4bit_quant_storage": torch.bfloat16,
         }
-        if not cfg.deepspeed:
+        if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
            # for some reason, this causes the loss to be off by an order of magnitude
            # but deepspeed needs this still in bfloat16
            bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -470,6 +495,13 @@ def load_model(
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            **bnb_config,
        )
+    elif cfg.adapter == "lora" and cfg.load_in_8bit:
+        bnb_config = {
+            "load_in_8bit": True,
+        }
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            **bnb_config,
+        )
 
     if cfg.load_in_8bit and cfg.adapter is not None:
         model_kwargs["load_in_8bit"] = True
@@ -517,7 +549,31 @@ def load_model(
     qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
 
     try:
+        skip_move_to_device = False
         if (
+            cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        ) and not qlora_fsdp:
+            model = load_sharded_model(
+                base_model,
+                model_config,
+                cfg,
+                torch_dtype=cfg.torch_dtype,
+            )
+            skip_move_to_device = True
+        elif (
+            qlora_fsdp
+            and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+            and cfg.model_config_type == "dbrx"
+        ):
+            quant_storage = cfg.torch_dtype
+            model = load_sharded_model_quant(
+                base_model,
+                model_config,
+                cfg,
+                quant_storage=quant_storage,
+            )
+            skip_move_to_device = True
+        elif (
             model_config.model_type == "llama"
             and not cfg.trust_remote_code
             and not cfg.gptq
@@ -597,6 +653,11 @@ def load_model(
                 **model_kwargs,
             )
         else:
+            if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+                skip_move_to_device = True
+                if "device_map" in model_kwargs:
+                    del model_kwargs["device_map"]
+
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=model_config,
@@ -670,13 +731,17 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
 
-    if ...
+    if is_deepspeed_zero3_enabled():
         from deepspeed.utils import (  # pylint: disable=no-name-in-module
             set_z3_leaf_modules,
         )
-        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
-        ...
+        if cfg.model_config_type == "mixtral":
+            moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
+            set_z3_leaf_modules(model, [moe_block])
+        elif cfg.model_config_type == "dbrx":
+            moe_block = get_module_class_from_name(model, "DbrxFFN")
+            set_z3_leaf_modules(model, [moe_block])
 
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
@@ -686,7 +751,8 @@ def load_model(
     if cfg.adapter == "lora" and loftq_bits:
         skip_prepare_model_for_kbit_training = True
 
-    if qlora_fsdp:
+    if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
+        # make sure everything is in the same dtype
         skip_prepare_model_for_kbit_training = True
 
     if cfg.adapter in ["lora", "qlora"]:
@@ -727,7 +793,7 @@ def load_model(
         cfg.ddp
         and not load_in_8bit
         and not (cfg.rl and cfg.load_in_4bit)
-        and not ...
+        and not skip_move_to_device
     ):
         # TODO revaldate this conditional
         model.to(f"cuda:{cfg.local_rank}")
@@ -883,7 +949,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
     rank = int(os.environ.get("LOCAL_RANK", 0))
 
-    if ...
+    if (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_meta_for_peft(model)
 
     if cfg.lora_model_dir:
@@ -908,7 +979,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
         LOG.warning(
             "Exception caught during model.print_trainable_parameters(): %s", exc
         )
-    elif ...
+    elif (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
```
(`...` marks removed lines that were truncated in the original diff view.)
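The `set_z3_leaf_modules` calls added above are the "z3 leaf moe fix" from the commit list. Below is a rough standalone sketch of the same idea for Mixtral, where the MoE block class is importable from transformers; the rationale in the comments is a summary, not DeepSpeed documentation.

```python
# Marking the MoE block as a ZeRO-3 "leaf" tells DeepSpeed to gather/partition it as
# one unit, which avoids hangs when only a subset of experts is activated per step.
from deepspeed.utils import set_z3_leaf_modules  # pylint: disable=no-name-in-module
from transformers import AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

# DBRX is a trust_remote_code model, so its DbrxFFN class isn't importable from
# transformers; that's why the PR looks the class up by name with
# get_module_class_from_name(model, "DbrxFFN") instead.
```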
src/axolotl/utils/trainer.py
CHANGED
```diff
@@ -306,6 +306,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
 
 def setup_fsdp_envs(cfg):
     os.environ["ACCELERATE_USE_FSDP"] = "true"
+    if cfg.fsdp_config.fsdp_activation_checkpointing:
+        os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
     if cfg.fsdp_config.fsdp_offload_params:
         os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
     if cfg.fsdp_config.fsdp_sync_module_states:
```