winglian committed on
Commit 5e616d9 · unverified · 2 parents: cd0a6f6 94f310c

Merge branch 'main' into strip-peft-device-map

FAQS.md CHANGED
@@ -2,3 +2,6 @@
2
 
3
  - Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
4
  - Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
 
 
 
 
2
 
3
  - Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
4
  - Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
5
+ - `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c`
6
+ `/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.`
7
+ This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.
README.md CHANGED
@@ -16,13 +16,14 @@
16
 
17
  ## Axolotl supports
18
 
- | | fp16/fp32 | fp16/fp32 w/ lora | qlora | 4bit-quant | 4bit-quant w/flash attention | flash attention | xformers attention |
- |---------|:----------|:------------------|------|------------|------------------------------|-----------------|--------------------|
- | llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
- | Pythia | ✅ | ✅ | | ❌ | ❌ | ❌ | ❓ |
- | cerebras | ✅ | ✅ | | ❌ | ❌ | ❌ | |
- | mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
- | falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | |
 
26
 
27
 
28
  ## Quickstart ⚡
@@ -38,10 +39,10 @@ pip3 install -U git+https://github.com/huggingface/peft.git
38
  accelerate config
39
 
40
  # finetune lora
41
- accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml
42
 
43
  # inference
44
- accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
45
  --inference --lora_model_dir="./lora-out"
46
  ```
47
 
@@ -381,6 +382,8 @@ num_epochs: 3
381
  warmup_steps: 100
382
  learning_rate: 0.00003
383
  logging_steps:
 
 
384
 
385
  # whether to mask out or include the human's prompt from the training labels
386
  train_on_inputs: false
 
16
 
17
  ## Axolotl supports
18
 
+ | | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
+ |----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
+ | llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+ | Pythia | ✅ | ✅ | | ❌ | ❓ | | ❌ | ❓ |
+ | cerebras | ✅ | ✅ | | ❌ | ❓ | | ❌ | |
+ | mpt | ✅ | ❌ | ❓ | ❌ | ❓ | | ❌ | ❓ |
+ | falcon | ✅ | ✅ | ✅ | ❌ | ❓ | | ❌ | |
+ | gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ |
27
 
28
 
29
  ## Quickstart ⚡
 
39
  accelerate config
40
 
41
  # finetune lora
42
+ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
43
 
44
  # inference
45
+ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
46
  --inference --lora_model_dir="./lora-out"
47
  ```
48
 
 
382
  warmup_steps: 100
383
  learning_rate: 0.00003
384
  logging_steps:
385
+ save_steps:
386
+ eval_steps:
387
 
388
  # whether to mask out or include the human's prompt from the training labels
389
  train_on_inputs: false
configs/accelerate/default_config.yaml DELETED
@@ -1,15 +0,0 @@
1
- compute_environment: LOCAL_MACHINE
2
- distributed_type: 'NO'
3
- downcast_bf16: 'no'
4
- gpu_ids: all
5
- machine_rank: 0
6
- main_training_function: main
7
- mixed_precision: bf16
8
- num_machines: 1
9
- num_processes: 1
10
- rdzv_backend: static
11
- same_network: true
12
- tpu_env: []
13
- tpu_use_cluster: false
14
- tpu_use_sudo: false
15
- use_cpu: false
configs/cerebras_1_3B_alpaca.yml DELETED
@@ -1,40 +0,0 @@
1
- base_model: cerebras/Cerebras-GPT-1.3B
2
- model_type: AutoModelForCausalLM
3
- tokenizer_type: AutoTokenizer
4
- load_in_8bit: true
5
- datasets:
6
- - path: data/alpaca_data_gpt4.jsonl
7
- type: alpaca
8
- - path: data/vicuna_cleaned.jsonl
9
- type: sharegpt
10
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
11
- type: gpteacher
12
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
13
- type: gpteacher
14
- dataset_prepared_path: last_run_prepared
15
- val_set_size: 0.05
16
- adapter: lora
17
- sequence_len: 2048
18
- lora_r: 8
19
- lora_alpha: 16
20
- lora_dropout: 0.05
21
- lora_target_modules:
22
- - c_attn
23
- lora_fan_in_fan_out: false
24
- wandb_project: pythia-1.4b-lora
25
- wandb_watch:
26
- wandb_run_id:
27
- wandb_log_model:
28
- output_dir: ./lora-alpaca
29
- gradient_accumulation_steps: 1
30
- micro_batch_size: 4
31
- num_epochs: 5
32
- learning_rate: 0.0003
33
- train_on_inputs: false
34
- group_by_length: false
35
- bf16: True
36
- tf32: True
37
- gradient_checkpointing:
38
- early_stopping_patience:
39
- resume_from_checkpoint:
40
- local_rank:
configs/galactica_1_3B.yml DELETED
@@ -1,41 +0,0 @@
1
- base_model: facebook/galactica-1.3b
2
- model_type: AutoModelForCausalLM
3
- tokenizer_type: AutoTokenizer
4
- load_in_8bit: false
5
- datasets:
6
- - path: tatsu-lab/alpaca
7
- type: alpaca
8
- dataset_prepared_path: last_run_prepared
9
- val_set_size: 0.1
10
- adapter:
11
- lora_model_dir:
12
- sequence_len: 1024
13
- max_packed_sequence_len: 1024
14
- lora_r: 8
15
- lora_alpha: 16
16
- lora_dropout: 0.05
17
- lora_target_modules:
18
- - q_proj
19
- - v_proj
20
- lora_fan_in_fan_out: false
21
- wandb_project:
22
- wandb_watch:
23
- wandb_run_id:
24
- wandb_log_model:
25
- output_dir: ./lora-llama-alpaca
26
- gradient_accumulation_steps: 1
27
- micro_batch_size: 16
28
- num_epochs: 3
29
- learning_rate: 0.00003
30
- train_on_inputs: false
31
- group_by_length: false
32
- bf16: false
33
- tf32: false
34
- early_stopping_patience:
35
- resume_from_checkpoint:
36
- local_rank:
37
- tokens:
38
- pad_token: "[PAD]"
39
- bos_token: "<s>"
40
- eos_token: "</s>"
41
- unk_token: "<unk>"
configs/gpt_neox_20b.yml DELETED
@@ -1,39 +0,0 @@
1
- base_model: EleutherAI/gpt-neox-20b
2
- base_model_ignore_patterns: pytorch* # prefer safetensors
3
- model_type: GPTNeoXForCausalLM
4
- tokenizer_type: AutoTokenizer
5
- load_in_8bit: true
6
- datasets:
7
- - path: nomic-ai/gpt4all-j-prompt-generations
8
- type: alpaca
9
- shards: 4
10
- shards_index: 0
11
- dataset_prepared_path: last_run_prepared
12
- val_set_size: 0.05
13
- adapter: lora
14
- lora_model_dir:
15
- sequence_len: 2048
16
- max_packed_sequence_len: 2048
17
- lora_r: 8
18
- lora_alpha: 32
19
- lora_dropout: 0.05
20
- lora_target_modules:
21
- - query_key_value
22
- lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
23
- wandb_project: gpt4all-neox-20b
24
- wandb_watch:
25
- wandb_run_id:
26
- wandb_log_model:
27
- output_dir: ./gpt4all-neox-20b
28
- gradient_accumulation_steps: 1
29
- micro_batch_size: 4
30
- num_epochs: 5
31
- learning_rate: 0.00003
32
- lr_scheduler: one_cycle
33
- train_on_inputs: false
34
- group_by_length: false
35
- bf16: True
36
- tf32: True
37
- early_stopping_patience:
38
- resume_from_checkpoint:
39
- local_rank:
configs/llama_13B_alpaca.yml DELETED
@@ -1,39 +0,0 @@
1
- base_model: huggyllama/llama-13b
2
- model_type: LlamaForCausalLM
3
- tokenizer_type: LlamaTokenizer
4
- load_in_8bit: true
5
- datasets:
6
- - path: anon8231489123/ShareGPT_Vicuna_unfiltered
7
- data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
8
- type: sharegpt
9
- dataset_prepared_path: last_run_prepared
10
- val_set_size: 0.002
11
- adapter:
12
- lora_model_dir:
13
- sequence_len: 2048
14
- lora_r: 8
15
- lora_alpha: 16
16
- lora_dropout: 0.05
17
- lora_target_modules:
18
- - q_proj
19
- - v_proj
20
- lora_fan_in_fan_out: false
21
- wandb_project:
22
- wandb_watch:
23
- wandb_run_id:
24
- wandb_log_model:
25
- output_dir: ./llama-13b-sharegpt
26
- gradient_accumulation_steps: 1
27
- micro_batch_size: 2
28
- warmup_steps: 1000
29
- save_steps:
30
- eval_steps:
31
- num_epochs: 5
32
- learning_rate: 0.00003
33
- train_on_inputs: false
34
- group_by_length: false
35
- bf16: true
36
- tf32: true
37
- early_stopping_patience: 5
38
- resume_from_checkpoint:
39
- local_rank:
configs/llama_65B_alpaca.yml DELETED
@@ -1,44 +0,0 @@
1
- base_model: huggyllama/llama-65b
2
- model_type: LlamaForCausalLM
3
- tokenizer_type: LlamaTokenizer
4
- load_in_8bit: true
5
- datasets:
6
- - path: data/alpaca_data_gpt4.jsonl
7
- type: alpaca
8
- - path: anon8231489123/ShareGPT_Vicuna_unfiltered
9
- data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
10
- type: sharegpt
11
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
12
- type: gpteacher
13
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
14
- type: gpteacher
15
- dataset_prepared_path: last_run_prepared
16
- val_set_size: 0.04
17
- adapter: lora
18
- lora_model_dir:
19
- sequence_len: 2048
20
- lora_r: 8
21
- lora_alpha: 16
22
- lora_dropout: 0.05
23
- lora_target_modules:
24
- - q_proj
25
- - v_proj
26
- lora_fan_in_fan_out: false
27
- wandb_project: llama-65b-lora
28
- wandb_watch:
29
- wandb_run_id:
30
- wandb_log_model:
31
- output_dir: ./lora-llama-alpaca
32
- gradient_accumulation_steps: 1
33
- micro_batch_size: 16
34
- warmup_steps: 1000
35
- save_steps:
36
- num_epochs: 5
37
- learning_rate: 0.00003
38
- train_on_inputs: false
39
- group_by_length: false
40
- bf16: true
41
- tf32: true
42
- early_stopping_patience:
43
- resume_from_checkpoint:
44
- local_rank:
configs/llama_7B_4bit.yml DELETED
@@ -1,45 +0,0 @@
1
- base_model: decapoda-research/llama-7b-hf-int4
2
- base_model_config: decapoda-research/llama-7b-hf
3
- model_type: LlamaForCausalLM
4
- tokenizer_type: LlamaTokenizer
5
- load_in_8bit: true
6
- datasets:
7
- - path: tatsu-lab/alpaca # original alpaca dataset
8
- type: alpaca
9
- dataset_prepared_path: data/last_run_prepared
10
- val_set_size: 0.04
11
- adapter: lora
12
- lora_model_dir:
13
- sequence_len: 2048
14
- max_packed_sequence_len: 1024
15
- lora_r: 8
16
- lora_alpha: 16
17
- lora_dropout: 0.05
18
- lora_target_modules:
19
- - q_proj
20
- - v_proj
21
- # - k_proj
22
- # - o_proj
23
- lora_fan_in_fan_out: false
24
- wandb_project:
25
- wandb_watch:
26
- wandb_run_id:
27
- wandb_log_model:
28
- output_dir: ./lora-test
29
- gradient_accumulation_steps: 1
30
- micro_batch_size: 2
31
- num_epochs: 3
32
- warmup_steps: 100
33
- learning_rate: 0.00003
34
- train_on_inputs: false
35
- group_by_length: false
36
- bf16: true
37
- tf32: true
38
- gradient_checkpointing: false
39
- early_stopping_patience: 3
40
- resume_from_checkpoint:
41
- auto_resume_from_checkpoints: true
42
- local_rank:
43
- load_4bit: true
44
- xformers_attention: true
45
- flash_attention:
configs/llama_7B_alpaca.yml DELETED
@@ -1,41 +0,0 @@
1
- base_model: huggyllama/llama-7b
2
- model_type: LlamaForCausalLM
3
- tokenizer_type: LlamaTokenizer
4
- load_in_8bit: true
5
- datasets:
6
- - path: data/alpaca_data_gpt4.jsonl
7
- type: alpaca
8
- - path: data/vicuna_cleaned.jsonl
9
- type: sharegpt
10
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
11
- type: gpteacher
12
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
13
- type: gpteacher
14
- dataset_prepared_path: last_run_prepared
15
- val_set_size: 0.04
16
- adapter: lora
17
- lora_model_dir:
18
- sequence_len: 2048
19
- lora_r: 8
20
- lora_alpha: 16
21
- lora_dropout: 0.05
22
- lora_target_modules:
23
- - q_proj
24
- - v_proj
25
- lora_fan_in_fan_out: false
26
- wandb_project: llama-7b-lora
27
- wandb_watch:
28
- wandb_run_id:
29
- wandb_log_model:
30
- output_dir: ./lora-llama-alpaca
31
- gradient_accumulation_steps: 1
32
- micro_batch_size: 16
33
- num_epochs: 5
34
- learning_rate: 0.00003
35
- train_on_inputs: false
36
- group_by_length: false
37
- bf16: true
38
- tf32: true
39
- early_stopping_patience:
40
- resume_from_checkpoint:
41
- local_rank:
configs/quickstart.yml DELETED
@@ -1,45 +0,0 @@
1
- base_model: decapoda-research/llama-7b-hf-int4
2
- base_model_config: decapoda-research/llama-7b-hf
3
- model_type: LlamaForCausalLM
4
- tokenizer_type: LlamaTokenizer
5
- load_in_8bit: true
6
- datasets:
7
- - path: tatsu-lab/alpaca # original alpaca dataset
8
- type: alpaca
9
- dataset_prepared_path: data/last_run_prepared
10
- val_set_size: 0.04
11
- adapter: lora
12
- lora_model_dir:
13
- sequence_len: 1024
14
- max_packed_sequence_len: 1024
15
- lora_r: 8
16
- lora_alpha: 16
17
- lora_dropout: 0.05
18
- lora_target_modules:
19
- - q_proj
20
- - v_proj
21
- # - k_proj
22
- # - o_proj
23
- lora_fan_in_fan_out: false
24
- wandb_project:
25
- wandb_watch:
26
- wandb_run_id:
27
- wandb_log_model:
28
- output_dir: ./lora-test
29
- gradient_accumulation_steps: 1
30
- micro_batch_size: 1
31
- num_epochs: 3
32
- warmup_steps: 100
33
- learning_rate: 0.00003
34
- train_on_inputs: false
35
- group_by_length: false
36
- bf16: true
37
- tf32: true
38
- gradient_checkpointing: false
39
- early_stopping_patience: 3
40
- resume_from_checkpoint:
41
- auto_resume_from_checkpoints: true
42
- local_rank:
43
- gptq: true
44
- xformers_attention: true
45
- flash_attention:
configs/sample.yml DELETED
@@ -1,87 +0,0 @@
1
- # this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
2
- # this can also be a relative path to a model on disk
3
- base_model: decapoda-research/llama-7b-hf-int4
4
- # you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
5
- base_model_ignore_patterns:
6
- # if the base_model repo on hf hub doesn't include configuration .json files,
7
- # you can set that here, or leave this empty to default to base_model
8
- base_model_config: decapoda-research/llama-7b-hf
9
- # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
10
- model_type: AutoModelForCausalLM
11
- # Corresponding tokenizer for the model AutoTokenizer is a good choice
12
- tokenizer_type: AutoTokenizer
13
- # whether you are training a 4-bit quantized model
14
- load_4bit: true
15
- # this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
16
- load_in_8bit: true
17
- # a list of one or more datasets to finetune the model with
18
- datasets:
19
- # this can be either a hf dataset, or relative path
20
- - path: vicgalle/alpaca-gpt4
21
- # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
22
- type: alpaca
23
- # axolotl attempts to save the dataset as an arrow after packing the data together so
24
- # subsequent training attempts load faster, relative path
25
- dataset_prepared_path: data/last_run_prepared
26
- # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
27
- val_set_size: 0.04
28
- # if you want to use lora, leave blank to train all parameters in original model
29
- adapter: lora
30
- # if you already have a lora model trained that you want to load, put that here
31
- lora_model_dir:
32
- # the maximum length of an input to train with, this should typically be less than 2048
33
- # as most models have a token/context limit of 2048
34
- sequence_len: 2048
35
- # max sequence length to concatenate training samples together up to
36
- # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
37
- max_packed_sequence_len: 1024
38
- # lora hyperparameters
39
- lora_r: 8
40
- lora_alpha: 16
41
- lora_dropout: 0.05
42
- lora_target_modules:
43
- - q_proj
44
- - v_proj
45
- # - k_proj
46
- # - o_proj
47
- lora_fan_in_fan_out: false
48
- # wandb configuration if you're using it
49
- wandb_project:
50
- wandb_watch:
51
- wandb_run_id:
52
- wandb_log_model:
53
- # where to save the finished model to
54
- output_dir: ./completed-model
55
- # training hyperparameters
56
- gradient_accumulation_steps: 1
57
- batch_size:
58
- micro_batch_size: 2
59
- num_epochs: 3
60
- warmup_steps: 100
61
- learning_rate: 0.00003
62
- # whether to mask out or include the human's prompt from the training labels
63
- train_on_inputs: false
64
- # don't use this, leads to wonky training (according to someone on the internet)
65
- group_by_length: false
66
- # Use CUDA bf16
67
- bf16: true
68
- # Use CUDA tf32
69
- tf32: true
70
- # does not work with current implementation of 4-bit LoRA
71
- gradient_checkpointing: false
72
- # stop training after this many evaluation losses have increased in a row
73
- # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
74
- early_stopping_patience: 3
75
- # specify a scheduler to use with the optimizer. only one_cycle is supported currently
76
- lr_scheduler:
77
- # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
78
- xformers_attention:
79
- # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
80
- flash_attention:
81
- # resume from a specific checkpoint dir
82
- resume_from_checkpoint:
83
- # if resume_from_checkpoint isn't set and you simply want it to start where it left off
84
- # be careful with this being turned on between different models
85
- auto_resume_from_checkpoints: false
86
- # don't mess with this, it's here for accelerate and torchrun
87
- local_rank:
configs/vicuna_13B_4bit_reflect.yml DELETED
@@ -1,45 +0,0 @@
1
- base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
2
- base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
3
- model_type: LlamaForCausalLM
4
- tokenizer_type: LlamaTokenizer
5
- load_in_8bit: false
6
- load_4bit: true
7
- gptq_groupsize: 128
8
- gptq_model_v1: false
9
- datasets:
10
- # https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
11
- - path: data/alpaca_reflect_pruned.jsonl
12
- type: reflection
13
- dataset_prepared_path: data/last_run_prepared
14
- val_set_size: 0.04
15
- adapter: lora
16
- lora_model_dir:
17
- sequence_len: 2048
18
- max_packed_sequence_len: 2048
19
- lora_r: 8
20
- lora_alpha: 16
21
- lora_dropout: 0.05
22
- lora_target_modules:
23
- - q_proj
24
- - v_proj
25
- # - k_proj
26
- # - o_proj
27
- lora_fan_in_fan_out: false
28
- wandb_project:
29
- wandb_watch:
30
- wandb_run_id:
31
- wandb_log_model:
32
- output_dir: ./lora-reflect
33
- gradient_accumulation_steps: 1
34
- micro_batch_size: 2
35
- num_epochs: 3
36
- learning_rate: 0.00003
37
- train_on_inputs: false
38
- group_by_length: false
39
- bf16: true
40
- tf32: true
41
- gradient_checkpointing: false
42
- early_stopping_patience: 3
43
- resume_from_checkpoint:
44
- local_rank:
45
- flash_attention: true
examples/cerebras/qlora.yml ADDED
@@ -0,0 +1,60 @@
1
+ base_model: cerebras/Cerebras-GPT-1.3B
2
+ base_model_config: cerebras/Cerebras-GPT-1.3B
3
+ load_in_8bit: false
4
+ load_in_4bit: true
5
+ strict: false
6
+ push_dataset_to_hub:
7
+ datasets:
8
+ - path: teknium/GPT4-LLM-Cleaned
9
+ type: alpaca
10
+ dataset_prepared_path: last_run_prepared
11
+ val_set_size: 0.01
12
+ adapter: qlora
13
+ lora_model_dir:
14
+ sequence_len: 2048
15
+ max_packed_sequence_len: 2048
16
+ lora_r: 16
17
+ lora_alpha: 32
18
+ lora_dropout: 0.05
19
+ lora_target_modules:
20
+ - c_fc
21
+ - c_attn
22
+ - c_proj
23
+ lora_target_linear:
24
+ lora_fan_in_fan_out:
25
+ wandb_project:
26
+ wandb_watch:
27
+ wandb_run_id:
28
+ wandb_log_model:
29
+ output_dir: ./qlora-out
30
+ batch_size: 4
31
+ micro_batch_size: 4
32
+ num_epochs: 2
33
+ optimizer: paged_adamw_8bit
34
+ torchdistx_path:
35
+ lr_scheduler: cosine
36
+ learning_rate: 0.0002
37
+ train_on_inputs: false
38
+ group_by_length: true
39
+ bf16: true
40
+ fp16: false
41
+ tf32: true
42
+ gradient_checkpointing: true
43
+ early_stopping_patience:
44
+ resume_from_checkpoint:
45
+ local_rank:
46
+ logging_steps: 1
47
+ xformers_attention: true
48
+ flash_attention:
49
+ gptq_groupsize:
50
+ gptq_model_v1:
51
+ warmup_steps: 10
52
+ eval_steps: 20
53
+ save_steps:
54
+ debug:
55
+ deepspeed:
56
+ weight_decay: 0.1
57
+ fsdp:
58
+ fsdp_config:
59
+ special_tokens:
60
+ pad_token: "<|endoftext|>"
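As a rough illustration of what the qlora block above configures, the sketch below maps the 4-bit and LoRA fields onto the peft/transformers objects they correspond to. The field values are taken from the YAML; the surrounding Python (object names, `task_type`) is an assumption for illustration and is not part of this commit.

```python
# Illustrative sketch only -- not part of this commit or of axolotl's loader.
# Field values mirror examples/cerebras/qlora.yml above.
from peft import LoraConfig
from transformers import BitsAndBytesConfig

# load_in_4bit: true -> quantize the base model to 4 bits via bitsandbytes
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

# lora_r / lora_alpha / lora_dropout / lora_target_modules from the YAML
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["c_fc", "c_attn", "c_proj"],
    task_type="CAUSAL_LM",  # assumption: causal LM task, as in the other examples
)
```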
examples/falcon/config-7b-lora.yml CHANGED
@@ -23,7 +23,7 @@ lora_dropout: 0.0
23
  lora_target_modules:
24
  lora_target_linear: true
25
  lora_fan_in_fan_out:
26
- wandb_project: falcon-7b
27
  wandb_watch:
28
  wandb_run_id:
29
  wandb_log_model:
 
23
  lora_target_modules:
24
  lora_target_linear: true
25
  lora_fan_in_fan_out:
26
+ wandb_project:
27
  wandb_watch:
28
  wandb_run_id:
29
  wandb_log_model:
examples/falcon/config-7b.yml CHANGED
@@ -23,7 +23,7 @@ lora_dropout: 0.0
23
  lora_target_modules:
24
  lora_target_linear: true
25
  lora_fan_in_fan_out:
26
- wandb_project: falcon-7b
27
  wandb_watch:
28
  wandb_run_id:
29
  wandb_log_model:
 
23
  lora_target_modules:
24
  lora_target_linear: true
25
  lora_fan_in_fan_out:
26
+ wandb_project:
27
  wandb_watch:
28
  wandb_run_id:
29
  wandb_log_model:
configs/stability_3b.yml → examples/gptj/qlora.yml RENAMED
@@ -1,38 +1,42 @@
1
- base_model: stabilityai/stablelm-base-alpha-3b
2
- base_model_config: stabilityai/stablelm-base-alpha-3b
3
  load_in_8bit: false
 
 
 
4
  datasets:
5
- - path: vicgalle/alpaca-gpt4
6
  type: alpaca
7
  dataset_prepared_path: last_run_prepared
8
- val_set_size: 0.04
9
- adapter:
10
  lora_model_dir:
11
- sequence_len: 4096
12
- max_packed_sequence_len: 4096
13
  lora_r: 8
14
- lora_alpha: 16
15
  lora_dropout: 0.05
16
  lora_target_modules:
17
- - q_proj
18
- - v_proj
19
- lora_fan_in_fan_out: false
20
- wandb_project: stable-alpaca-3b
21
  wandb_watch:
22
  wandb_run_id:
23
  wandb_log_model:
24
- output_dir: ./stable-alpaca-3b
25
- gradient_accumulation_steps: 1
26
- micro_batch_size: 1
27
- num_epochs: 1
28
- optimizer: adamw_bnb_8bit
29
  torchdistx_path:
30
  lr_scheduler: cosine
31
- learning_rate: 0.0000002
32
  train_on_inputs: false
33
- group_by_length: false
34
  bf16: true
 
35
  tf32: true
 
36
  early_stopping_patience:
37
  resume_from_checkpoint:
38
  local_rank:
@@ -41,16 +45,13 @@ xformers_attention: true
41
  flash_attention:
42
  gptq_groupsize:
43
  gptq_model_v1:
44
- warmup_steps: 100
45
- eval_steps: 50
46
- save_steps: 200
47
  debug:
48
  deepspeed:
49
- weight_decay: 0.01
50
  fsdp:
51
  fsdp_config:
52
- #tokens:
53
- # pad_token: "[PAD]"
54
- # bos_token: "<s>"
55
- # eos_token: "</s>"
56
- # unk_token: "<unk>"
 
1
+ base_model: EleutherAI/gpt-j-6b
2
+ base_model_config: EleutherAI/gpt-j-6b
3
  load_in_8bit: false
4
+ load_in_4bit: true
5
+ strict: false
6
+ push_dataset_to_hub:
7
  datasets:
8
+ - path: teknium/GPT4-LLM-Cleaned
9
  type: alpaca
10
  dataset_prepared_path: last_run_prepared
11
+ val_set_size: 0.01
12
+ adapter: qlora
13
  lora_model_dir:
14
+ sequence_len: 2048
15
+ max_packed_sequence_len:
16
  lora_r: 8
17
+ lora_alpha: 32
18
  lora_dropout: 0.05
19
  lora_target_modules:
20
+ lora_target_linear: true
21
+ lora_fan_in_fan_out:
22
+ wandb_project:
 
23
  wandb_watch:
24
  wandb_run_id:
25
  wandb_log_model:
26
+ output_dir: ./qlora-out
27
+ gradient_accumulation_steps: 2
28
+ micro_batch_size: 2
29
+ num_epochs: 2
30
+ optimizer: paged_adamw_8bit
31
  torchdistx_path:
32
  lr_scheduler: cosine
33
+ learning_rate: 0.0001
34
  train_on_inputs: false
35
+ group_by_length: true
36
  bf16: true
37
+ fp16: false
38
  tf32: true
39
+ gradient_checkpointing: true
40
  early_stopping_patience:
41
  resume_from_checkpoint:
42
  local_rank:
 
45
  flash_attention:
46
  gptq_groupsize:
47
  gptq_model_v1:
48
+ warmup_steps: 10
49
+ eval_steps: 20
50
+ save_steps:
51
  debug:
52
  deepspeed:
53
+ weight_decay: 0.1
54
  fsdp:
55
  fsdp_config:
56
+ special_tokens:
57
+ pad_token: "<|endoftext|>"
 
 
 
examples/gptq-lora-7b/README.md CHANGED
@@ -3,6 +3,6 @@
3
  This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
4
 
5
  ```shell
6
- accelerate launch scripts/finetune.py examples/4bit-lora-7b/config.yml
7
 
8
  ```
 
3
  This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
4
 
5
  ```shell
6
+ accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
7
 
8
  ```
configs/llama_7B_jeopardy.yml → examples/jeopardy-bot/config.yml RENAMED
@@ -7,30 +7,28 @@ datasets:
7
  - path: openaccess-ai-collective/jeopardy
8
  type: jeopardy
9
  dataset_prepared_path: last_run_prepared
10
- val_set_size: 0.01
11
  adapter:
12
  lora_model_dir:
13
- sequence_len: 2048
14
- max_packed_sequence_len: 2048
15
- lora_r: 8
16
- lora_alpha: 16
17
- lora_dropout: 0.05
18
  lora_target_modules:
19
- - q_proj
20
- - v_proj
21
  lora_fan_in_fan_out: false
22
- wandb_project: jeopardy-bot-7b
23
  wandb_watch:
24
  wandb_run_id:
25
  wandb_log_model:
26
  output_dir: ./jeopardy-bot-7b
27
- gradient_accumulation_steps: 2
28
  micro_batch_size: 1
29
- num_epochs: 2
30
  optimizer: adamw_bnb_8bit
31
  torchdistx_path:
32
  lr_scheduler: cosine
33
- learning_rate: 0.0000002
34
  train_on_inputs: false
35
  group_by_length: false
36
  bf16: true
@@ -48,11 +46,10 @@ eval_steps: 110
48
  save_steps: 660
49
  debug:
50
  deepspeed:
51
- weight_decay: 0.0001
52
  fsdp:
53
  fsdp_config:
54
  tokens:
55
- pad_token: "[PAD]"
56
  bos_token: "<s>"
57
  eos_token: "</s>"
58
  unk_token: "<unk>"
 
7
  - path: openaccess-ai-collective/jeopardy
8
  type: jeopardy
9
  dataset_prepared_path: last_run_prepared
10
+ val_set_size: 0.02
11
  adapter:
12
  lora_model_dir:
13
+ sequence_len: 512
14
+ max_packed_sequence_len:
15
+ lora_r:
16
+ lora_alpha:
17
+ lora_dropout:
18
  lora_target_modules:
 
 
19
  lora_fan_in_fan_out: false
20
+ wandb_project:
21
  wandb_watch:
22
  wandb_run_id:
23
  wandb_log_model:
24
  output_dir: ./jeopardy-bot-7b
25
+ gradient_accumulation_steps: 1
26
  micro_batch_size: 1
27
+ num_epochs: 3
28
  optimizer: adamw_bnb_8bit
29
  torchdistx_path:
30
  lr_scheduler: cosine
31
+ learning_rate: 0.00003
32
  train_on_inputs: false
33
  group_by_length: false
34
  bf16: true
 
46
  save_steps: 660
47
  debug:
48
  deepspeed:
49
+ weight_decay: 0.1
50
  fsdp:
51
  fsdp_config:
52
  tokens:
 
53
  bos_token: "<s>"
54
  eos_token: "</s>"
55
  unk_token: "<unk>"
examples/openllama-3b/README.md ADDED
@@ -0,0 +1,16 @@
1
+ # openllama-3b
2
+
3
+ Basic full tune
4
+ ```shell
5
+ accelerate launch scripts/finetune.py examples/openllama-3b/config.yml
6
+ ```
7
+
8
+ LoRA
9
+ ```shell
10
+ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
11
+ ```
12
+
13
+ QLoRA
14
+ ```shell
15
+ accelerate launch scripts/finetune.py examples/openllama-3b/qlora.yml
16
+ ```
examples/openllama-3b/config.yml ADDED
@@ -0,0 +1,61 @@
1
+ base_model: openlm-research/open_llama_3b
2
+ base_model_config: openlm-research/open_llama_3b
3
+ model_type: LlamaForCausalLM
4
+ tokenizer_type: LlamaTokenizer
5
+ load_in_8bit: false
6
+ load_in_4bit: false
7
+ strict: false
8
+ push_dataset_to_hub:
9
+ datasets:
10
+ - path: teknium/GPT4-LLM-Cleaned
11
+ type: alpaca
12
+ dataset_prepared_path: last_run_prepared
13
+ val_set_size: 0.02
14
+ adapter:
15
+ lora_model_dir:
16
+ sequence_len: 256
17
+ max_packed_sequence_len:
18
+ lora_r:
19
+ lora_alpha:
20
+ lora_dropout:
21
+ lora_target_modules:
22
+ lora_target_linear:
23
+ lora_fan_in_fan_out:
24
+ wandb_project:
25
+ wandb_watch:
26
+ wandb_run_id:
27
+ wandb_log_model:
28
+ output_dir: ./openllama-out
29
+ batch_size: 16
30
+ micro_batch_size: 4
31
+ num_epochs: 3
32
+ optimizer: adamw_bnb_8bit
33
+ torchdistx_path:
34
+ lr_scheduler: cosine
35
+ learning_rate: 0.0002
36
+ train_on_inputs: false
37
+ group_by_length: false
38
+ bf16: false
39
+ fp16: true
40
+ tf32: false
41
+ gradient_checkpointing: true
42
+ early_stopping_patience:
43
+ resume_from_checkpoint:
44
+ local_rank:
45
+ logging_steps: 1
46
+ xformers_attention: true
47
+ flash_attention:
48
+ gptq_groupsize:
49
+ gptq_model_v1:
50
+ warmup_steps: 10
51
+ eval_steps: 50
52
+ save_steps:
53
+ debug:
54
+ deepspeed:
55
+ weight_decay: 0.0
56
+ fsdp:
57
+ fsdp_config:
58
+ special_tokens:
59
+ bos_token: "<s>"
60
+ eos_token: "</s>"
61
+ unk_token: "<unk>"
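One detail worth spelling out from the config above: `batch_size: 16` with `micro_batch_size: 4` implies gradient accumulation across micro-batches. A tiny sketch of the arithmetic (illustrative only; the exact derivation is left to the trainer and may differ with device count):

```python
# Illustrative arithmetic only -- not code from this commit.
batch_size = 16            # batch_size from examples/openllama-3b/config.yml
micro_batch_size = 4       # micro_batch_size from the same config

# Assuming a single device, the effective accumulation is the ratio of the two.
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps == 4
```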
examples/{lora-openllama-3b/config.yml → openllama-3b/lora.yml} RENAMED
@@ -1,5 +1,5 @@
1
- base_model: openlm-research/open_llama_3b_600bt_preview
2
- base_model_config: openlm-research/open_llama_3b_600bt_preview
3
  model_type: LlamaForCausalLM
4
  tokenizer_type: LlamaTokenizer
5
  load_in_8bit: true
@@ -49,7 +49,7 @@ early_stopping_patience:
49
  resume_from_checkpoint:
50
  local_rank:
51
  logging_steps: 1
52
- xformers_attention:
53
  flash_attention:
54
  gptq_groupsize:
55
  gptq_model_v1:
 
1
+ base_model: openlm-research/open_llama_3b
2
+ base_model_config: openlm-research/open_llama_3b
3
  model_type: LlamaForCausalLM
4
  tokenizer_type: LlamaTokenizer
5
  load_in_8bit: true
 
49
  resume_from_checkpoint:
50
  local_rank:
51
  logging_steps: 1
52
+ xformers_attention: true
53
  flash_attention:
54
  gptq_groupsize:
55
  gptq_model_v1:
examples/{qlora-openllama-3b/config.yml → openllama-3b/qlora.yml} RENAMED
@@ -1,5 +1,5 @@
1
- base_model: openlm-research/open_llama_3b_600bt_preview
2
- base_model_config: openlm-research/open_llama_3b_600bt_preview
3
  model_type: LlamaForCausalLM
4
  tokenizer_type: LlamaTokenizer
5
  load_in_8bit: false
 
1
+ base_model: openlm-research/open_llama_3b
2
+ base_model_config: openlm-research/open_llama_3b
3
  model_type: LlamaForCausalLM
4
  tokenizer_type: LlamaTokenizer
5
  load_in_8bit: false
configs/pythia_1_2B_alpaca.yml → examples/pythia/lora.yml RENAMED
@@ -1,36 +1,29 @@
1
  base_model: EleutherAI/pythia-1.4b-deduped
2
- model_type: GPTNeoXForCausalLM
3
- tokenizer_type: AutoTokenizer
4
  load_in_8bit: true
5
  datasets:
6
- - path: data/alpaca_data_gpt4.jsonl
7
  type: alpaca
8
- - path: data/vicuna_cleaned.jsonl
9
- type: sharegpt
10
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
11
- type: gpteacher
12
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
13
- type: gpteacher
14
  dataset_prepared_path: last_run_prepared
15
  val_set_size: 0.05
16
  adapter: lora
17
  lora_model_dir:
18
- sequence_len: 2048
19
- lora_r: 8
20
  lora_alpha: 32
21
  lora_dropout: 0.05
22
  lora_target_modules:
23
  - query_key_value
24
- # - xxx
25
  lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
26
- wandb_project: pythia-1.4b-lora
27
  wandb_watch:
28
  wandb_run_id:
29
  wandb_log_model:
30
- output_dir: ./lora-alpaca
31
  gradient_accumulation_steps: 1
32
  micro_batch_size: 4
33
- num_epochs: 5
34
  learning_rate: 0.00001
35
  train_on_inputs: false
36
  group_by_length: false
@@ -39,3 +32,6 @@ tf32: True
39
  early_stopping_patience:
40
  resume_from_checkpoint:
41
  local_rank:
 
 
 
 
1
  base_model: EleutherAI/pythia-1.4b-deduped
2
+ base_model_config: EleutherAI/pythia-1.4b-deduped
 
3
  load_in_8bit: true
4
  datasets:
5
+ - path: teknium/GPT4-LLM-Cleaned
6
  type: alpaca
 
 
 
 
 
 
7
  dataset_prepared_path: last_run_prepared
8
  val_set_size: 0.05
9
  adapter: lora
10
  lora_model_dir:
11
+ sequence_len: 512
12
+ lora_r: 16
13
  lora_alpha: 32
14
  lora_dropout: 0.05
15
  lora_target_modules:
16
  - query_key_value
17
+ lora_target_linear:
18
  lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
19
+ wandb_project:
20
  wandb_watch:
21
  wandb_run_id:
22
  wandb_log_model:
23
+ output_dir: ./lora-alpaca-pythia
24
  gradient_accumulation_steps: 1
25
  micro_batch_size: 4
26
+ num_epochs: 3
27
  learning_rate: 0.00001
28
  train_on_inputs: false
29
  group_by_length: false
 
32
  early_stopping_patience:
33
  resume_from_checkpoint:
34
  local_rank:
35
+ weight_decay: 0.1
36
+ eval_steps: 20
37
+ logging_steps: 1
examples/qlora-openllama-3b/README.md DELETED
@@ -1,6 +0,0 @@
1
- # qlora-openllama-3b
2
-
3
- ```shell
4
- accelerate launch scripts/finetune.py examples/qlora-openllama-3b/config.yml
5
-
6
- ```
scripts/finetune.py CHANGED
@@ -165,7 +165,7 @@ def train(
165
  cfg_keys = cfg.keys()
166
  for k, _ in kwargs.items():
167
  # if not strict, allow writing to cfg even if it's not in the yml already
168
- if k in cfg_keys or cfg.strict is False:
169
  # handle booleans
170
  if isinstance(cfg[k], bool):
171
  cfg[k] = bool(kwargs[k])
@@ -205,8 +205,8 @@ def train(
205
  logging.info(f"loading tokenizer... {tokenizer_config}")
206
  tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
207
 
208
- if check_not_in(
209
- ["inference", "shard", "merge_lora"], kwargs
210
  ): # don't need to load dataset for these
211
  train_dataset, eval_dataset = load_prepare_datasets(
212
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
@@ -234,7 +234,6 @@ def train(
234
  tokenizer,
235
  cfg,
236
  adapter=cfg.adapter,
237
- inference=("inference" in kwargs),
238
  )
239
 
240
  if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -247,7 +246,7 @@ def train(
247
  model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
248
  return
249
 
250
- if "inference" in kwargs:
251
  logging.info("calling do_inference function")
252
  inf_kwargs: Dict[str, Any] = {}
253
  if "prompter" in kwargs:
 
165
  cfg_keys = cfg.keys()
166
  for k, _ in kwargs.items():
167
  # if not strict, allow writing to cfg even if it's not in the yml already
168
+ if k in cfg_keys or not cfg.strict:
169
  # handle booleans
170
  if isinstance(cfg[k], bool):
171
  cfg[k] = bool(kwargs[k])
 
205
  logging.info(f"loading tokenizer... {tokenizer_config}")
206
  tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
207
 
208
+ if (
209
+ check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
210
  ): # don't need to load dataset for these
211
  train_dataset, eval_dataset = load_prepare_datasets(
212
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
 
234
  tokenizer,
235
  cfg,
236
  adapter=cfg.adapter,
 
237
  )
238
 
239
  if "merge_lora" in kwargs and cfg.adapter is not None:
 
246
  model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
247
  return
248
 
249
+ if cfg.inference:
250
  logging.info("calling do_inference function")
251
  inf_kwargs: Dict[str, Any] = {}
252
  if "prompter" in kwargs:
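The first hunk's behavior is easier to see in isolation: a CLI kwarg is written into the config when the key already exists in the yml, or whenever strict mode is off, and booleans keep their type. The sketch below mirrors that rule with a plain dict standing in for axolotl's config object; it is illustrative, not the actual `scripts/finetune.py`.

```python
# Minimal sketch of the kwarg-override rule after this change (illustrative).
# A plain dict stands in for axolotl's DictDefault-style config.
def apply_cli_overrides(cfg: dict, kwargs: dict) -> dict:
    cfg_keys = cfg.keys()
    for k, v in kwargs.items():
        # if not strict, allow writing to cfg even if it's not in the yml already
        if k in cfg_keys or not cfg.get("strict"):
            # handle booleans the same way the hunk above does
            cfg[k] = bool(v) if isinstance(cfg.get(k), bool) else v
    return cfg

# "inference" is not in the yml, but strict is falsy, so the override is accepted.
cfg = apply_cli_overrides({"strict": False, "micro_batch_size": 2}, {"inference": True})
assert cfg["inference"] is True and cfg["micro_batch_size"] == 2
```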
src/axolotl/utils/models.py CHANGED
@@ -77,15 +77,9 @@ def load_tokenizer(
77
 
78
 
79
  def load_model(
80
- base_model,
81
- base_model_config,
82
- model_type,
83
- tokenizer,
84
- cfg,
85
- adapter="lora",
86
- inference=False,
87
  ):
88
- # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
89
  """
90
  Load a model from a base model and a model type.
91
  """
@@ -98,7 +92,7 @@ def load_model(
98
  )
99
 
100
  if cfg.is_llama_derived_model and cfg.flash_attention:
101
- if cfg.device not in ["mps", "cpu"] and inference is False:
102
  from axolotl.flash_attn import replace_llama_attn_with_flash_attn
103
 
104
  logging.info("patching with flash attention")
@@ -305,7 +299,9 @@ def load_model(
305
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
306
  ):
307
  logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
308
- model = prepare_model_for_kbit_training(model)
 
 
309
 
310
  model, lora_config = load_adapter(model, cfg, adapter)
311
 
@@ -436,6 +432,7 @@ def load_lora(model, cfg):
436
  model = PeftModel.from_pretrained(
437
  model,
438
  cfg.lora_model_dir,
 
439
  )
440
  else:
441
  model = get_peft_model(model, lora_config)
 
77
 
78
 
79
  def load_model(
80
+ base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
 
 
 
 
 
 
81
  ):
82
+ # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
83
  """
84
  Load a model from a base model and a model type.
85
  """
 
92
  )
93
 
94
  if cfg.is_llama_derived_model and cfg.flash_attention:
95
+ if cfg.device not in ["mps", "cpu"] and not cfg.inference:
96
  from axolotl.flash_attn import replace_llama_attn_with_flash_attn
97
 
98
  logging.info("patching with flash attention")
 
299
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
300
  ):
301
  logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
302
+ model = prepare_model_for_kbit_training(
303
+ model, use_gradient_checkpointing=cfg.gradient_checkpointing
304
+ )
305
 
306
  model, lora_config = load_adapter(model, cfg, adapter)
307
 
 
432
  model = PeftModel.from_pretrained(
433
  model,
434
  cfg.lora_model_dir,
435
+ is_trainable=not cfg.inference,
436
  )
437
  else:
438
  model = get_peft_model(model, lora_config)
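Two functional changes land in this file: the gradient-checkpointing preference now flows into `prepare_model_for_kbit_training`, and a previously trained LoRA is loaded frozen when running inference via `is_trainable`. A condensed sketch of just those two peft calls (illustrative, not the actual `load_model`/`load_lora`; `model` is assumed to be an already-loaded quantized model and the directory name is a placeholder):

```python
# Illustrative sketch of the two peft calls touched in this hunk.
from peft import PeftModel, prepare_model_for_kbit_training

def attach_lora(model, lora_model_dir: str, inference: bool, gradient_checkpointing: bool):
    # cfg.gradient_checkpointing now reaches the kbit preparation step.
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=gradient_checkpointing
    )
    # An existing LoRA loads frozen for inference, trainable otherwise.
    return PeftModel.from_pretrained(model, lora_model_dir, is_trainable=not inference)
```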
src/axolotl/utils/validation.py CHANGED
@@ -57,6 +57,11 @@ def validate_config(cfg):
57
  if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
58
  raise ValueError("FSDP is not supported for falcon models")
59
 
 
 
 
 
 
60
  # TODO
61
  # MPT 7b
62
  # https://github.com/facebookresearch/bitsandbytes/issues/25
 
57
  if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
58
  raise ValueError("FSDP is not supported for falcon models")
59
 
60
+ if (
61
+ cfg.base_model and "mpt" in cfg.base_model.lower()
62
+ ) and cfg.gradient_checkpointing:
63
+ raise ValueError("gradient_checkpointing is not supported for MPT models")
64
+
65
  # TODO
66
  # MPT 7b
67
  # https://github.com/facebookresearch/bitsandbytes/issues/25
tests/test_validation.py CHANGED
@@ -198,3 +198,17 @@ class ValidationTest(unittest.TestCase):
198
  )
199
 
200
  validate_config(cfg)
 
198
  )
199
 
200
  validate_config(cfg)
201
+
202
+ def test_mpt_gradient_checkpointing(self):
203
+ regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
204
+
205
+ # Check for lower-case
206
+ cfg = DictDefault(
207
+ {
208
+ "base_model": "mosaicml/mpt-7b",
209
+ "gradient_checkpointing": True,
210
+ }
211
+ )
212
+
213
+ with pytest.raises(ValueError, match=regex_exp):
214
+ validate_config(cfg)
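For the passing side of the new MPT guard (the case the test above does not exercise), a quick sketch; the import paths are assumptions based on the files touched in this commit and may differ:

```python
# Illustrative only -- import paths assumed, not verified against the repo layout.
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config

cfg = DictDefault(
    {
        "base_model": "mosaicml/mpt-7b",
        "gradient_checkpointing": False,  # disabled, so validation passes
    }
)
validate_config(cfg)
```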