Thytu committed on
Commit dd00657
1 Parent(s): c3d2562

refactor(param): rename load_4bit config param to gptq

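For anyone carrying an existing config, the change amounts to renaming one key. Below is a minimal migration sketch; the migrate_config helper is hypothetical (it is not part of this commit) and assumes PyYAML is installed. Since yaml.safe_dump does not preserve comments, hand-editing the key works just as well.

# Hypothetical helper, not part of this commit: rename the deprecated
# `load_4bit` key to `gptq` in an existing axolotl YAML config.
import yaml  # assumes PyYAML is available


def migrate_config(path: str) -> None:
    with open(path) as f:
        cfg = yaml.safe_load(f)
    if "load_4bit" in cfg:
        # Same boolean value, new name.
        cfg["gptq"] = cfg.pop("load_4bit")
    with open(path, "w") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)  # note: comments are not preserved


migrate_config("configs/quickstart.yml")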
README.md CHANGED
@@ -176,7 +176,7 @@ tokenizer_type: AutoTokenizer
 trust_remote_code:
 
 # whether you are training a 4-bit GPTQ quantized model
-load_4bit: true
+gptq: true
 gptq_groupsize: 128 # group size
 gptq_model_v1: false # v1 or v2
 
configs/quickstart.yml CHANGED
@@ -40,6 +40,6 @@ early_stopping_patience: 3
 resume_from_checkpoint:
 auto_resume_from_checkpoints: true
 local_rank:
-load_4bit: true
+gptq: true
 xformers_attention: true
 flash_attention:
examples/4bit-lora-7b/config.yml CHANGED
@@ -4,7 +4,7 @@ model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 trust_remote_code:
 load_in_8bit: true
-load_4bit: true
+gptq: true
 datasets:
   - path: vicgalle/alpaca-gpt4
     type: alpaca
src/axolotl/utils/models.py CHANGED
@@ -73,7 +73,7 @@ def load_model(
     else:
         torch_dtype = torch.float32
     try:
-        if cfg.load_4bit:
+        if cfg.gptq:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
                 replace_peft_model_with_int4_lora_model,
             )
@@ -95,7 +95,7 @@ def load_model(
             bnb_4bit_quant_type="nf4",
         )
     try:
-        if cfg.load_4bit and is_llama_derived_model:
+        if cfg.gptq and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
             from huggingface_hub import snapshot_download
 
@@ -248,7 +248,7 @@ def load_model(
 
     if (
         ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
-        and not cfg.load_4bit
+        and not cfg.gptq
         and (load_in_8bit or cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
@@ -259,7 +259,7 @@ def load_model(
     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
-    if cfg.load_4bit:
+    if cfg.gptq:
         # Scales to half
         logging.info("Fitting 4bit scales and zeros to half")
         for n, m in model.named_modules():
@@ -274,7 +274,7 @@ def load_model(
     if (
         torch.cuda.device_count() > 1
        and int(os.getenv("WORLD_SIZE", "1")) > 1
-        and cfg.load_4bit
+        and cfg.gptq
     ):
         # llama is PROBABLY model parallelizable, but the default isn't that it is
         # so let's only set it for the 4bit, see
src/axolotl/utils/trainer.py CHANGED
@@ -63,7 +63,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
     if cfg.gradient_checkpointing is not None:
-        if cfg.load_4bit:
+        if cfg.gptq:
             from alpaca_lora_4bit.gradient_checkpointing import (
                 apply_gradient_checkpointing,
             )
@@ -138,7 +138,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         importlib.import_module("torchdistx")
     if (
         cfg.optimizer == "adamw_bnb_8bit"
-        and not cfg.load_4bit
+        and not cfg.gptq
         and not "deepspeed" in training_arguments_kwargs
         and not cfg.fsdp
     ):
src/axolotl/utils/validation.py CHANGED
@@ -2,16 +2,20 @@ import logging
 
 
 def validate_config(cfg):
+    if cfg.load_4bit:
+        raise ValueError("cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq")
+
     if cfg.adapter == "qlora":
         if cfg.merge_lora:
             # can't merge qlora if loaded in 8bit or 4bit
             assert cfg.load_in_8bit is False
-            assert cfg.load_4bit is False
+            assert cfg.gptq is False
             assert cfg.load_in_4bit is False
         else:
             assert cfg.load_in_8bit is False
-            assert cfg.load_4bit is False
+            assert cfg.gptq is False
             assert cfg.load_in_4bit is True
+
     if not cfg.load_in_8bit and cfg.adapter == "lora":
         logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
 
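With the new guard, a stale config fails fast instead of the old key being silently ignored. A minimal sketch of that behaviour follows; using SimpleNamespace as a stand-in for axolotl's real config object is an assumption made for brevity.

# Sketch only: exercise the deprecation guard added to validate_config above.
# SimpleNamespace stands in for axolotl's config object (an assumption).
from types import SimpleNamespace

from axolotl.utils.validation import validate_config

legacy_cfg = SimpleNamespace(load_4bit=True)

try:
    validate_config(legacy_cfg)
except ValueError as err:
    # -> cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq
    print(err)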