tmm1 committed
Commit 8cec513 • 1 Parent(s): a13e45d

extract module for working with cfg

scripts/finetune.py CHANGED
@@ -19,6 +19,7 @@ from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import barrier, is_main_process
@@ -29,7 +30,6 @@ from axolotl.utils.trainer import (
     process_datasets_for_packing,
     setup_trainer,
 )
-from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -44,27 +44,6 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 
-def choose_device(cfg):
-    def get_device():
-        try:
-            if torch.cuda.is_available():
-                return f"cuda:{cfg.local_rank}"
-
-            if torch.backends.mps.is_available():
-                return "mps"
-
-            raise SystemError("No CUDA/mps device found")
-        except Exception:  # pylint: disable=broad-exception-caught
-            return "cpu"
-
-    cfg.device = get_device()
-    if cfg.device_map != "auto":
-        if cfg.device.startswith("cuda"):
-            cfg.device_map = {"": cfg.local_rank}
-        else:
-            cfg.device_map = {"": cfg.device}
-
-
 def get_multi_line_input() -> Optional[str]:
     print("Give me an instruction (Ctrl + D to finish): ")
     instruction = ""
@@ -194,31 +173,9 @@ def train(
 
     validate_config(cfg)
 
-    # setup some derived config / hyperparams
-    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
-        cfg.batch_size // cfg.micro_batch_size
-    )
-    cfg.batch_size = (
-        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
-    )
-    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
-    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
-    choose_device(cfg)
-    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
-    if cfg.ddp:
-        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
-        cfg.batch_size = cfg.batch_size * cfg.world_size
+    normalize_config(cfg)
 
     setup_wandb_env_vars(cfg)
-    if cfg.device == "mps":
-        cfg.load_in_8bit = False
-        cfg.tf32 = False
-        if cfg.bf16:
-            cfg.fp16 = True
-        cfg.bf16 = False
-
-    if cfg.tf32:
-        torch.backends.cuda.matmul.allow_tf32 = True
 
     # load the tokenizer first
     tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
src/axolotl/utils/{validation.py → config.py} RENAMED
@@ -1,12 +1,60 @@
-"""Module for validating config files"""
+"""Module for working with config dicts"""
 
 import logging
+import os
 
 import torch
 
 LOG = logging.getLogger("axolotl")
 
 
+def choose_device(cfg):
+    def get_device():
+        try:
+            if torch.cuda.is_available():
+                return f"cuda:{cfg.local_rank}"
+
+            if torch.backends.mps.is_available():
+                return "mps"
+
+            raise SystemError("No CUDA/mps device found")
+        except Exception:  # pylint: disable=broad-exception-caught
+            return "cpu"
+
+    cfg.device = get_device()
+    if cfg.device_map != "auto":
+        if cfg.device.startswith("cuda"):
+            cfg.device_map = {"": cfg.local_rank}
+        else:
+            cfg.device_map = {"": cfg.device}
+
+
+def normalize_config(cfg):
+    # setup some derived config / hyperparams
+    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
+        cfg.batch_size // cfg.micro_batch_size
+    )
+    cfg.batch_size = (
+        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
+    )
+    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
+    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    choose_device(cfg)
+    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
+    if cfg.ddp:
+        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
+        cfg.batch_size = cfg.batch_size * cfg.world_size
+
+    if cfg.device == "mps":
+        cfg.load_in_8bit = False
+        cfg.tf32 = False
+        if cfg.bf16:
+            cfg.fp16 = True
+        cfg.bf16 = False
+    else:
+        torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
+
+
 def validate_config(cfg):
     if cfg.max_packed_sequence_len and cfg.sample_packing:
         raise ValueError(
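For orientation, here is a minimal sketch of how the extracted helpers are meant to be consumed after this change, mirroring the call order now used in scripts/finetune.py. The config field values below are illustrative assumptions, not taken from the commit, and the snippet assumes a single-process run with axolotl importable on the path.

# sketch only: validate the config, then derive runtime settings from it
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "micro_batch_size": 2,             # illustrative value
        "gradient_accumulation_steps": 4,  # illustrative value
    }
)

validate_config(cfg)   # raises ValueError on conflicting options
normalize_config(cfg)  # fills batch_size, world_size, local_rank, device, device_map

print(cfg.batch_size)  # 8 = micro_batch_size * gradient_accumulation_steps
print(cfg.device)      # "cuda:0", "mps", or "cpu" depending on the host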
tests/test_validation.py CHANGED
@@ -6,8 +6,8 @@ from typing import Optional
 
 import pytest
 
+from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.validation import validate_config
 
 
 class ValidationTest(unittest.TestCase):
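The existing tests still only exercise validate_config. A hypothetical follow-up test for the newly extracted normalize_config could look like the sketch below; it is not part of this commit, and the expected values assume a single-process environment where WORLD_SIZE and LOCAL_RANK are unset.

import unittest

from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault


class NormalizeConfigTest(unittest.TestCase):
    """Hypothetical test for the derived batch-size hyperparams."""

    def test_batch_size_derived_from_micro_batch(self):
        cfg = DictDefault(
            {
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 3,
            }
        )
        normalize_config(cfg)
        # batch_size falls back to micro_batch_size * gradient_accumulation_steps
        self.assertEqual(cfg.batch_size, 6)
        self.assertEqual(cfg.world_size, 1)
        self.assertFalse(cfg.ddp)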