more checks and fixes for deepspeed and fsdp (#1208) [skip ci]

- deepspeed_configs/zero1.json  +0 -9
- deepspeed_configs/zero2.json  +0 -9
- deepspeed_configs/zero3.json  +0 -9
- deepspeed_configs/zero3_bf16.json  +0 -9
- src/axolotl/utils/config.py  +28 -20
- src/axolotl/utils/models.py  +10 -8
deepspeed_configs/zero1.json CHANGED

@@ -15,15 +15,6 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
-  "optimizer": {
-    "type": "AdamW",
-    "params": {
-      "lr": "auto",
-      "betas": "auto",
-      "eps": "auto",
-      "weight_decay": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
deepspeed_configs/zero2.json CHANGED

@@ -19,15 +19,6 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
-  "optimizer": {
-    "type": "AdamW",
-    "params": {
-      "lr": "auto",
-      "betas": "auto",
-      "eps": "auto",
-      "weight_decay": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
deepspeed_configs/zero3.json CHANGED

@@ -23,15 +23,6 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
-  "optimizer": {
-    "type": "AdamW",
-    "params": {
-      "lr": "auto",
-      "betas": "auto",
-      "eps": "auto",
-      "weight_decay": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
deepspeed_configs/zero3_bf16.json CHANGED

@@ -23,15 +23,6 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
-  "optimizer": {
-    "type": "AdamW",
-    "params": {
-      "lr": "auto",
-      "betas": "auto",
-      "eps": "auto",
-      "weight_decay": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
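Note: the edit is identical across all four shipped ZeRO configs. With the hard-coded AdamW `optimizer` block removed, DeepSpeed no longer constructs its own optimizer and instead wraps the one the training loop builds from the axolotl config, which also avoids tripping the new conflicting-optimizer warning added in config.py below. For anyone carrying a custom DeepSpeed config, a minimal sketch of the same cleanup (the file path is hypothetical, not part of this PR):

import json
from pathlib import Path

# Drop the hard-coded optimizer block from a custom DeepSpeed config so the
# optimizer selected in the axolotl YAML is used instead of DeepSpeed's AdamW.
path = Path("deepspeed_configs/custom.json")  # hypothetical path
ds_cfg = json.loads(path.read_text(encoding="utf-8"))
ds_cfg.pop("optimizer", None)  # no-op if the block is already absent
path.write_text(json.dumps(ds_cfg, indent=2) + "\n", encoding="utf-8")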
src/axolotl/utils/config.py CHANGED

@@ -95,7 +95,7 @@ def normalize_config(cfg):
         save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
         if save_steps < 1.0:  # prevent saves on every step
             cfg.save_steps = save_steps
-    if cfg.evals_per_epoch:
+    if (cfg.val_set_size or cfg.test_datasets) and cfg.evals_per_epoch:
         eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
         if eval_steps < 1.0:  # prevent evals on every step
             cfg.eval_steps = eval_steps

@@ -485,35 +485,43 @@ def validate_config(cfg):
                 "`use_reentrant` must be false when used with partially frozen model."
             )

-    if cfg.flash_attention and cfg.deepspeed and Path(cfg.deepspeed).is_file():
+    if cfg.deepspeed and Path(cfg.deepspeed).is_file():
         with open(cfg.deepspeed, encoding="utf-8") as file:
             contents = file.read()
             deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
-            if (
-                deepspeed_cfg.zero_optimization
-                and deepspeed_cfg.zero_optimization.stage == 3
-            ):
-                if not (
-                    (
-                        deepspeed_cfg.bf16
-                        and deepspeed_cfg.bf16.enabled  # pylint: disable=no-member
-                        is True
-                    )
-                    or (
-                        deepspeed_cfg.fp16
-                        and deepspeed_cfg.fp16.enabled  # pylint: disable=no-member
-                        is True
-                    )
+            if cfg.flash_attention:
+                if (
+                    deepspeed_cfg.zero_optimization
+                    and deepspeed_cfg.zero_optimization.stage == 3
                 ):
-                    raise ValueError(
-                        "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
-                    )
+                    if not (
+                        (
+                            deepspeed_cfg.bf16
+                            and deepspeed_cfg.bf16.enabled  # pylint: disable=no-member
+                            is True
+                        )
+                        or (
+                            deepspeed_cfg.fp16
+                            and deepspeed_cfg.fp16.enabled  # pylint: disable=no-member
+                            is True
+                        )
+                    ):
+                        raise ValueError(
+                            "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
+                        )
+            if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
+                LOG.warning(
+                    f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
+                )

     if cfg.test_datasets and cfg.val_set_size:
         raise ValueError(
             "non-zero val_set_size should not be used with test_datasets configuration"
         )

+    if cfg.fsdp and "bnb" in cfg.optimizer:
+        raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
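Because the JSON is loaded into a DictDefault, missing keys (`zero_optimization`, `bf16`, `fp16`, `optimizer`) resolve to falsy values instead of raising AttributeError, which is what lets the chained truthiness checks above stay flat. A standalone sketch of the same validations using plain dict lookups (the function name and signature are illustrative, not part of the PR):

import json
from pathlib import Path


def check_deepspeed_cfg(path: str, optimizer: str, flash_attention: bool) -> None:
    # Mirror of the new validate_config checks, with .get() in place of DictDefault.
    ds_cfg = json.loads(Path(path).read_text(encoding="utf-8"))
    if flash_attention and ds_cfg.get("zero_optimization", {}).get("stage") == 3:
        bf16_on = ds_cfg.get("bf16", {}).get("enabled") is True
        fp16_on = ds_cfg.get("fp16", {}).get("enabled") is True
        if not (bf16_on or fp16_on):
            raise ValueError(
                "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
            )
    if "8bit" in optimizer and "optimizer" in ds_cfg:
        print(f"conflicting optimizer: {optimizer} used alongside deepspeed optimizer.")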
src/axolotl/utils/models.py CHANGED

@@ -642,15 +642,17 @@ def load_model(

     # make sure these are fp32 per Ramesh et al. (2021)
     embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
-    for name, module in model.named_modules():
-        if "norm" in name:
-            module.to(torch.float32)
-        if model_config.model_type == "btlm":
-            # don't upcast lm_head for btlm
-            continue
-        if any(m in name for m in embedding_modules):
-            if hasattr(module, "weight"):
+    if not cfg.fsdp:
+        # FSDP doesn't like mixed Float and BFloat16
+        for name, module in model.named_modules():
+            if any(m in name for m in ["norm", "gate"]):
                 module.to(torch.float32)
+            if model_config.model_type == "btlm":
+                # don't upcast lm_head for btlm
+                continue
+            if any(m in name for m in embedding_modules):
+                if hasattr(module, "weight"):
+                    module.to(torch.float32)

     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
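The motivation for the new `if not cfg.fsdp:` guard is the comment in the diff: FSDP's flat-parameter sharding generally expects a uniform dtype within each flattened group, so selectively upcasting norm/gate/embedding weights to float32 inside an otherwise bfloat16 model can break wrapping. A small diagnostic sketch (hypothetical helper, not part of this change) for spotting a mixed-dtype model before it reaches FSDP:

from collections import Counter

import torch


def param_dtype_counts(model: torch.nn.Module) -> Counter:
    # Tally parameters by dtype; FSDP generally wants exactly one entry here.
    return Counter(p.dtype for p in model.parameters())


# usage sketch:
# counts = param_dtype_counts(model)
# if len(counts) > 1:
#     print(f"mixed parameter dtypes may break FSDP wrapping: {dict(counts)}")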