winglian committed
Commit 4cb7900
1 Parent(s): 18f8119

Peft lotfq (#1222)


* loftq support for lora
* fix loftq check
* update readme for loftq
* readability cleanup
* use peft main for loftq fixes, remove unnecessary special tokens
* remove unused test from older deprecation

README.md CHANGED
@@ -696,6 +696,12 @@ lora_modules_to_save:
 
 lora_fan_in_fan_out: false
 
+peft:
+  # Configuration options for loftq initialization for LoRA
+  # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
+  loftq_config:
+    loftq_bits: # typically 4 bits
+
 # ReLoRA configuration
 # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
 relora_steps: # Number of steps per ReLoRA restart
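The two new keys are nested under a top-level peft: block, which matches the cfg.peft.loftq_config.loftq_bits lookups introduced in src/axolotl/utils/config.py and src/axolotl/utils/models.py further down; a complete config exercising them is added below as examples/llama-2/loftq.yml.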
examples/llama-2/fft_optimized.yml CHANGED
@@ -67,6 +67,3 @@ weight_decay: 0.1
 fsdp:
 fsdp_config:
 special_tokens:
-  bos_token: "<s>"
-  eos_token: "</s>"
-  unk_token: "<unk>"
examples/llama-2/loftq.yml ADDED
@@ -0,0 +1,70 @@
+base_model: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./lora-out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+peft:
+  loftq_config:
+    loftq_bits: 4
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_table_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
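As with the other llama-2 examples, this config can presumably be launched with the usual entry point, for example: accelerate launch -m axolotl.cli.train examples/llama-2/loftq.yml. Note that load_in_8bit and load_in_4bit stay false here; with LoftQ, PEFT quantizes the base weights itself during adapter initialization according to loftq_bits.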
examples/llama-2/lora.yml CHANGED
@@ -65,6 +65,3 @@ weight_decay: 0.0
 fsdp:
 fsdp_config:
 special_tokens:
-  bos_token: "<s>"
-  eos_token: "</s>"
-  unk_token: "<unk>"
examples/llama-2/qlora.yml CHANGED
@@ -65,6 +65,3 @@ weight_decay: 0.0
 fsdp:
 fsdp_config:
 special_tokens:
-  bos_token: "<s>"
-  eos_token: "</s>"
-  unk_token: "<unk>"
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
-peft==0.7.1
+peft @ git+https://github.com/huggingface/peft.git
 transformers==4.37.0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
src/axolotl/utils/config.py CHANGED
@@ -232,9 +232,6 @@ def validate_config(cfg):
             "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
         )
 
-    if cfg.load_4bit:
-        raise ValueError("cfg.load_4bit parameter has been deprecated")
-
     if cfg.adapter == "qlora":
         if cfg.merge_lora:
             # can't merge qlora if loaded in 8bit or 4bit
@@ -260,7 +257,8 @@ def validate_config(cfg):
         if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
             raise ValueError("Fused modules are not supported with QLoRA")
 
-    if not cfg.load_in_8bit and cfg.adapter == "lora":
+    loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
+    if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
         LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
 
     if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
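The chained expression is falsy unless a bit width is actually configured, so the load_in_8bit recommendation is only suppressed for LoRA runs that opt into LoftQ, which load the base model unquantized by design.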
src/axolotl/utils/models.py CHANGED
@@ -9,7 +9,7 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
-from peft import PeftConfig, prepare_model_for_kbit_training
+from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
 from peft.tuners.lora import QuantLinear
 from transformers import (  # noqa: F401
     AddedToken,
@@ -667,13 +667,17 @@ def load_model(
         # Qwen doesn't play nicely with LoRA if this is enabled
         skip_prepare_model_for_kbit_training = True
 
-    if (cfg.adapter == "lora" and load_in_8bit) or (
-        cfg.adapter == "qlora" and cfg.load_in_4bit
-    ):
-        LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
+    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
+    if cfg.adapter == "lora" and loftq_bits:
+        skip_prepare_model_for_kbit_training = True
+
+    if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
-        if not skip_prepare_model_for_kbit_training:
+        if (
+            cfg.load_in_8bit or cfg.load_in_4bit
+        ) and not skip_prepare_model_for_kbit_training:
+            LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
             model = prepare_model_for_kbit_training(
                 model, use_gradient_checkpointing=cfg.gradient_checkpointing
             )
@@ -700,6 +704,7 @@ def load_model(
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
     if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
+        # TODO revaldate this conditional
        model.to(f"cuda:{cfg.local_rank}")
 
     if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
@@ -797,6 +802,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
         LOG.info(f"found linear modules: {repr(linear_names)}")
         lora_target_modules = list(set(lora_target_modules + linear_names))
 
+    lora_config_kwargs = {}
+    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
+    if loftq_bits:
+        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
+        lora_config_kwargs["init_lora_weights"] = "loftq"
+
     lora_config = LoraConfig(
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
@@ -807,6 +818,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
         bias="none",
         task_type="CAUSAL_LM",
+        **lora_config_kwargs,
     )
 
     if config_only:
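For orientation, here is a minimal standalone sketch of what the loftq_config wiring above amounts to in PEFT itself. The model id and LoRA hyperparameters are taken from examples/llama-2/loftq.yml; everything else is an illustrative assumption, not the axolotl code path.

# Illustrative sketch only: LoftQ-initialized LoRA configured directly with PEFT.
from peft import LoftQConfig, LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# LoftQ expects the full-precision base model; quantization is handled inside
# the initialization itself.
model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")

lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    init_lora_weights="loftq",               # select LoftQ initialization
    loftq_config=LoftQConfig(loftq_bits=4),  # mirrors peft.loftq_config.loftq_bits
)
model = get_peft_model(model, lora_config)

In the diff above, these same two kwargs (init_lora_weights and loftq_config) are only injected into LoraConfig via lora_config_kwargs when cfg.peft.loftq_config.loftq_bits is set, so existing LoRA configs are unaffected.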
tests/test_validation.py CHANGED
@@ -32,16 +32,6 @@ class ValidationTest(BaseValidation):
     Test the validation module
     """
 
-    def test_load_4bit_deprecate(self):
-        cfg = DictDefault(
-            {
-                "load_4bit": True,
-            }
-        )
-
-        with pytest.raises(ValueError):
-            validate_config(cfg)
-
     def test_batch_size_unused_warning(self):
         cfg = DictDefault(
             {