Determine FSDP/deepspeed settings on device select. (#883)
Browse files
* Determine FSDP/deepspeed settings on device select.
Without this, the OS env check for accelerate will fail.
* rename and move env setup call
* chore: lint
---------
Co-authored-by: Karl-Johan Alm <kalle@gmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
src/axolotl/cli/__init__.py
CHANGED
|
@@ -29,6 +29,7 @@ from axolotl.utils.dict import DictDefault
|
|
| 29 |
from axolotl.utils.distributed import is_main_process
|
| 30 |
from axolotl.utils.models import load_tokenizer
|
| 31 |
from axolotl.utils.tokenization import check_dataset_labels
|
|
|
|
| 32 |
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
| 33 |
|
| 34 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
@@ -296,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|
| 296 |
|
| 297 |
validate_config(cfg)
|
| 298 |
|
|
|
|
|
|
|
| 299 |
normalize_config(cfg)
|
| 300 |
|
| 301 |
setup_wandb_env_vars(cfg)
|
|
|
|
| 29 |
from axolotl.utils.distributed import is_main_process
|
| 30 |
from axolotl.utils.models import load_tokenizer
|
| 31 |
from axolotl.utils.tokenization import check_dataset_labels
|
| 32 |
+
from axolotl.utils.trainer import prepare_optim_env
|
| 33 |
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
| 34 |
|
| 35 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
|
|
| 297 |
|
| 298 |
validate_config(cfg)
|
| 299 |
|
| 300 |
+
prepare_optim_env(cfg)
|
| 301 |
+
|
| 302 |
normalize_config(cfg)
|
| 303 |
|
| 304 |
setup_wandb_env_vars(cfg)
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -267,12 +267,14 @@ def setup_fsdp_envs(cfg):
|
|
| 267 |
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
| 268 |
|
| 269 |
|
| 270 |
-
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
| 271 |
if cfg.fsdp:
|
| 272 |
setup_fsdp_envs(cfg)
|
| 273 |
elif cfg.deepspeed:
|
| 274 |
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
| 275 |
|
|
|
|
|
|
|
| 276 |
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
| 277 |
trainer_builder.train_dataset = train_dataset
|
| 278 |
trainer_builder.eval_dataset = eval_dataset
|
|
|
|
| 267 |
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
| 268 |
|
| 269 |
|
| 270 |
+
def prepare_optim_env(cfg):
|
| 271 |
if cfg.fsdp:
|
| 272 |
setup_fsdp_envs(cfg)
|
| 273 |
elif cfg.deepspeed:
|
| 274 |
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
| 275 |
|
| 276 |
+
|
| 277 |
+
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
| 278 |
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
| 279 |
trainer_builder.train_dataset = train_dataset
|
| 280 |
trainer_builder.eval_dataset = eval_dataset
|