Peft deepspeed resume (#1227)
* import deepspeed integration
* monkeypatch peft adapter with deepspeed for resume from checkpoint
* fix patch
* fix patches attempt 2
* make sure to set lora_model_dir
* skip pylint for deepspeed.utils
* pick up upstream fix in transformers
* remove monkeypatch for deepspeed/peft fix
* no need to set the lora_model_dir on resume
* unset load_in_*bit when using quant config
* guard before del
* better handling of load_in* kwargs
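In short: the auto-resume checkpoint lookup in src/axolotl/train.py now runs before the model is loaded, requirements.txt pins transformers to the commit carrying the upstream DeepSpeed/PEFT resume fix (so the earlier monkeypatch could be dropped), and the load_in_8bit/load_in_4bit flags are routed through model_kwargs in src/axolotl/utils/models.py and stripped again whenever an explicit quantization config or GPTQ is in use.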
- requirements.txt +1 -1
- src/axolotl/cli/train.py +4 -3
- src/axolotl/train.py +15 -15
- src/axolotl/utils/models.py +15 -7
requirements.txt  CHANGED
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft @ git+https://github.com/huggingface/peft.git
-transformers
+transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.26.1
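The pin to commit bebeeee0 picks up the upstream transformers fix referenced in the commit messages above; once that fix reaches a release, the requirement can presumably go back to a plain transformers entry.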
src/axolotl/cli/train.py  CHANGED
@@ -6,8 +6,9 @@ from pathlib import Path
 from typing import Tuple
 
 import fire
-import transformers
-from transformers import PreTrainedModel
+from transformers.hf_argparser import HfArgumentParser
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
 
 from axolotl.cli import (
     check_accelerate_default_config,
@@ -27,7 +28,7 @@ LOG = logging.getLogger("axolotl.cli.train")
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     parsed_cfg = load_cfg(config, **kwargs)
-    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parser = HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
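For context, here is a minimal sketch of the HfArgumentParser pattern do_cli relies on; DemoCliArgs is a hypothetical stand-in for axolotl's TrainerCliArgs dataclass:

from dataclasses import dataclass, field

from transformers.hf_argparser import HfArgumentParser


@dataclass
class DemoCliArgs:
    # stand-in fields; the real TrainerCliArgs defines its own options
    debug: bool = field(default=False)
    prompter: str = field(default="")


parser = HfArgumentParser((DemoCliArgs))
# return_remaining_strings=True hands back any argv entries that do not map to
# dataclass fields (e.g. axolotl config overrides) instead of raising
parsed_cli_args, remaining = parser.parse_args_into_dataclasses(
    return_remaining_strings=True
)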
src/axolotl/train.py  CHANGED
@@ -57,6 +57,21 @@ def train(
     eval_dataset = dataset_meta.eval_dataset
     total_num_steps = dataset_meta.total_num_steps
 
+    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
+        if len(possible_checkpoints) > 0:
+            sorted_paths = sorted(
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
+            )
+            cfg.resume_from_checkpoint = sorted_paths[-1]
+            LOG.info(
+                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+            )
+    resume_from_checkpoint = cfg.resume_from_checkpoint
+
     # Load the model and tokenizer
     msg = "loading model"
     if cfg.adapter:
@@ -79,21 +94,6 @@
 
     safe_serialization = cfg.save_safetensors is True
 
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
-    resume_from_checkpoint = cfg.resume_from_checkpoint
-
     if cfg.unfrozen_parameters:
         freeze_parameters_except(model, cfg.unfrozen_parameters)
 
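A small worked example of why the sort key casts the step suffix to int (the paths below are made up): a plain lexicographic comparison would rank checkpoint-900 above checkpoint-10000, so auto-resume would pick the wrong directory.

# hypothetical checkpoint directories found under cfg.output_dir
possible_checkpoints = [
    "out/checkpoint-900",
    "out/checkpoint-1800",
    "out/checkpoint-10000",
]

# numeric sort on the step suffix, same key as used in train()
sorted_paths = sorted(possible_checkpoints, key=lambda path: int(path.split("-")[-1]))
print(sorted_paths[-1])           # out/checkpoint-10000 -- latest checkpoint
print(max(possible_checkpoints))  # out/checkpoint-900   -- lexicographic order, wrong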
src/axolotl/utils/models.py  CHANGED
@@ -473,6 +473,18 @@ def load_model(
             **bnb_config,
         )
 
+    if cfg.load_in_8bit and cfg.adapter is not None:
+        model_kwargs["load_in_8bit"] = True
+    if cfg.load_in_4bit and cfg.adapter is not None:
+        model_kwargs["load_in_4bit"] = True
+
+    # no longer needed per https://github.com/huggingface/transformers/pull/26610
+    if "quantization_config" in model_kwargs or cfg.gptq:
+        if "load_in_8bit" in model_kwargs:
+            del model_kwargs["load_in_8bit"]
+        if "load_in_4bit" in model_kwargs:
+            del model_kwargs["load_in_4bit"]
+
     # sample packing uses custom FA2 patch
     if cfg.flash_attention:
         if not cfg.sample_packing:
@@ -506,8 +518,6 @@ def load_model(
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             config=model_config,
-            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             **model_kwargs,
         )
 
@@ -575,8 +585,6 @@ def load_model(
         model = getattr(transformers, model_type).from_pretrained(
            base_model,
            config=model_config,
-           load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-           load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
            trust_remote_code=cfg.trust_remote_code or False,
            **model_kwargs,
         )
@@ -608,8 +616,6 @@ def load_model(
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=model_config,
-            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -678,7 +684,9 @@ def load_model(
     skip_prepare_model_for_kbit_training = False
 
     if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
-        from deepspeed.utils import set_z3_leaf_modules
+        from deepspeed.utils import (  # pylint: disable=no-name-in-module
+            set_z3_leaf_modules,
+        )
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         set_z3_leaf_modules(model, [MixtralSparseMoeBlock])