winglian commited on
Commit
ad2b48c
1 Parent(s): 9190ada

fdsp config dict fix, todo list, add torchdistx support

Browse files
TODO.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # todo list
2
+
3
+ - [] Validation of parameters for combinations that won't work
4
+
5
+
6
+
7
+ ## things that are known not to work
8
+
9
+ - FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
10
+ - adamw_bnb_8bit doesn't play well with FSDP offload
src/axolotl/utils/models.py CHANGED
@@ -179,6 +179,11 @@ def load_model(
179
  m.scales = m.scales.half()
180
  m.bias = m.bias.half()
181
 
 
 
 
 
 
182
  # TODO resume_from_checkpoint handling
183
  return model, tokenizer, lora_config
184
 
 
179
  m.scales = m.scales.half()
180
  m.bias = m.bias.half()
181
 
182
+ if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1:
183
+ model.is_parallelizable = True
184
+ model.model_parallel = True
185
+
186
+
187
  # TODO resume_from_checkpoint handling
188
  return model, tokenizer, lora_config
189
 
src/axolotl/utils/trainer.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import math
2
  import os
 
3
  from pathlib import Path
4
 
5
  import bitsandbytes as bnb
@@ -35,9 +37,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
35
  else:
36
  training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
37
  if cfg.fsdp:
38
- training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
39
- if cfg.fsdp_transformer_layer_cls_to_wrap:
40
- training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap
41
 
42
 
43
  # deepspeed
@@ -73,6 +75,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
73
 
74
  trainer_kwargs = {}
75
 
 
 
 
 
76
  if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
77
  decay_parameters = get_parameter_names(model, [nn.LayerNorm])
78
  decay_parameters = [name for name in decay_parameters if "bias" not in name]
 
1
+ import importlib
2
  import math
3
  import os
4
+ import sys
5
  from pathlib import Path
6
 
7
  import bitsandbytes as bnb
 
37
  else:
38
  training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
39
  if cfg.fsdp:
40
+ training_arguments_kwargs["fsdp"] = cfg.fsdp
41
+ if cfg.fsdp_config:
42
+ training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
43
 
44
 
45
  # deepspeed
 
75
 
76
  trainer_kwargs = {}
77
 
78
+ if cfg.optimizer == "adamw_anyprecision":
79
+ if Path(cfg.torchdistx_path).exists():
80
+ sys.path.append(cfg.torchdistx_path)
81
+ torchdistx = importlib.import_module('torchdistx')
82
  if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
83
  decay_parameters = get_parameter_names(model, [nn.LayerNorm])
84
  decay_parameters = [name for name in decay_parameters if "bias" not in name]