fsdp config dict fix, todo list, add torchdistx support
- TODO.md +10 -0
- src/axolotl/utils/models.py +5 -0
- src/axolotl/utils/trainer.py +9 -3
TODO.md
ADDED
@@ -0,0 +1,10 @@
+# todo list
+
+- [] Validation of parameters for combinations that won't work
+
+
+
+## things that are known not to work
+
+- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
+- adamw_bnb_8bit doesn't play well with FSDP offload
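For context, the two "known not to work" items map to concrete training-argument combinations. Below is a minimal illustration (not part of this commit) using the standard transformers `TrainingArguments` fields `fsdp`, `gradient_checkpointing`, and `optim`; `output_dir` and the particular fsdp strings are placeholders.

```python
from transformers import TrainingArguments

# FSDP CPU offload combined with gradient checkpointing is the broken pairing
# tracked in https://github.com/pytorch/pytorch/issues/82203.
offload_plus_checkpointing = TrainingArguments(
    output_dir="./out",                # placeholder path
    fsdp="full_shard offload",         # FSDP with CPU offload of params/optimizer state
    gradient_checkpointing=True,       # avoid enabling together with offload
)

# adamw_bnb_8bit is the 8-bit optimizer flagged above as misbehaving with FSDP offload.
offload_plus_8bit_adam = TrainingArguments(
    output_dir="./out",
    fsdp="full_shard offload",
    optim="adamw_bnb_8bit",
)
```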
src/axolotl/utils/models.py
CHANGED
@@ -179,6 +179,11 @@ def load_model(
                 m.scales = m.scales.half()
                 m.bias = m.bias.half()

+    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
+
+
     # TODO resume_from_checkpoint handling
     return model, tokenizer, lora_config
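Setting `is_parallelizable` and `model_parallel` is how a model signals to the Hugging Face `Trainer` that its weights are already spread across devices, so the trainer should not wrap it in `torch.nn.DataParallel` or move it to a single GPU. A rough paraphrase of that check, for context only; `uses_naive_model_parallel` is a hypothetical helper name, not transformers or axolotl code:

```python
def uses_naive_model_parallel(model) -> bool:
    """Approximate the flag check transformers.Trainer performs at init time."""
    return bool(
        getattr(model, "is_parallelizable", False)
        and getattr(model, "model_parallel", False)
    )

# When this is True, the Trainer treats the model as already placed across devices,
# which is what the new branch above requests whenever more than one GPU is visible
# and WORLD_SIZE indicates a multi-process run.
```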
src/axolotl/utils/trainer.py
CHANGED
@@ -1,5 +1,7 @@
+import importlib
 import math
 import os
+import sys
 from pathlib import Path

 import bitsandbytes as bnb
|
|
@@ -35,9 +37,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
     if cfg.fsdp:
-        training_arguments_kwargs["fsdp"] = cfg.fsdp
-        if cfg.
-            training_arguments_kwargs["
+        training_arguments_kwargs["fsdp"] = cfg.fsdp
+        if cfg.fsdp_config:
+            training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)


     # deepspeed
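The `dict()` call is the actual fix here: `TrainingArguments.fsdp_config` expects a plain dict (or a path to a JSON file), while `cfg.fsdp_config` comes out of the YAML config as an attribute-style mapping, so the conversion hands transformers an ordinary dict. A short sketch of the intended end result; `LlamaDecoderLayer`, the fsdp string, and `output_dir` are placeholders, and the exact keys accepted inside `fsdp_config` vary by transformers version:

```python
from transformers import TrainingArguments

# Stand-ins for cfg.fsdp / cfg.fsdp_config from the YAML config; values are examples only.
fsdp = "full_shard auto_wrap"
fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer"}

training_args = TrainingArguments(
    output_dir="./out",
    fsdp=fsdp,
    # Mirrors the dict(cfg.fsdp_config) conversion above: transformers receives
    # a plain dict rather than the config-wrapper type.
    fsdp_config=dict(fsdp_config),
)
```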
|
|
@@ -73,6 +75,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):

     trainer_kwargs = {}

+    if cfg.optimizer == "adamw_anyprecision":
+        if Path(cfg.torchdistx_path).exists():
+            sys.path.append(cfg.torchdistx_path)
+            torchdistx = importlib.import_module('torchdistx')
     if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
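torchdistx is not distributed on PyPI, and transformers resolves `optim="adamw_anyprecision"` by importing `AnyPrecisionAdamW` from `torchdistx.optimizers`, so appending a local build to `sys.path` before the trainer constructs its optimizer is all the glue needed. A hedged sketch of the downstream usage; the install path, `output_dir`, and the `optim_args` values are illustrative assumptions, not defaults taken from this commit:

```python
import importlib
import sys
from pathlib import Path

from transformers import TrainingArguments

# Hypothetical location of a locally built torchdistx; in this commit it comes
# from cfg.torchdistx_path in the YAML config.
torchdistx_path = "/opt/torchdistx/src/python"

if Path(torchdistx_path).exists():
    sys.path.append(torchdistx_path)
    importlib.import_module("torchdistx")  # fail early if the build is unusable

training_args = TrainingArguments(
    output_dir="./out",
    optim="adamw_anyprecision",
    # optim_args is a comma-separated key=value string that transformers forwards
    # to AnyPrecisionAdamW when building the optimizer.
    optim_args="use_kahan_summation=True,momentum_dtype=bfloat16",
)
```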