tmm1 committed
Commit 2e22404
1 Parent(s): be294fd

add utils.data.prepare_dataset

Files changed (2):
  1. scripts/finetune.py +3 -34
  2. src/axolotl/utils/data.py +35 -0
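
This commit moves the dataset loading, packing, and step-count logic that previously lived inline in scripts/finetune.py into a new prepare_dataset(cfg, tokenizer) helper in src/axolotl/utils/data.py, which returns (train_dataset, eval_dataset, total_num_steps). A minimal sketch of calling the new helper is below; the config fields and tokenizer name are illustrative assumptions, and a real run would carry the full set of dataset options from its YAML config:

from transformers import AutoTokenizer

from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault

# Hypothetical config values for illustration; a real cfg is built from the YAML
# config and also needs the usual dataset fields (datasets, val_set_size, ...).
cfg = DictDefault(
    {
        "pretraining_dataset": None,  # falsy -> takes the load_prepare_datasets branch
        "sequence_len": 2048,         # passed as max_tokens in the pretraining branch
        "max_steps": None,            # falsy -> total steps computed from the dataset
        "seed": 42,
    }
)

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # illustrative model
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)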
scripts/finetune.py CHANGED
@@ -19,16 +19,11 @@ from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
 from axolotl.utils.config import normalize_config, validate_config
-from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
+from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import (
-    calculate_total_num_steps,
-    process_datasets_for_packing,
-    setup_trainer,
-)
+from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -39,7 +34,6 @@ configure_logging()
 LOG = logging.getLogger("axolotl.scripts")
 
 
-DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 
@@ -183,32 +177,7 @@ def train(
     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
-        if not cfg.pretraining_dataset:
-            train_dataset, eval_dataset = load_prepare_datasets(
-                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-            )
-        else:
-            train_dataset = load_pretraining_dataset(
-                cfg.pretraining_dataset,
-                tokenizer,
-                max_tokens=cfg.sequence_len,
-                seed=cfg.seed or 42,
-            )
-            # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
-            train_dataset = train_dataset.with_format("torch")
-            eval_dataset = None
-
-        with zero_first(is_main_process()):
-            train_dataset, eval_dataset = process_datasets_for_packing(
-                cfg, train_dataset, eval_dataset
-            )
-        if cfg.max_steps:
-            total_num_steps = min(
-                calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
-            )
-            LOG.info(f"Maximum number of steps set at {total_num_steps}")
-        else:
-            total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+        train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
 
     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
src/axolotl/utils/data.py CHANGED
@@ -42,8 +42,43 @@ from axolotl.prompters import (
     SummarizeTLDRPrompter,
 )
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.trainer import (
+    calculate_total_num_steps,
+    process_datasets_for_packing,
+)
 
 LOG = logging.getLogger("axolotl")
+DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
+
+
+def prepare_dataset(cfg, tokenizer):
+    if not cfg.pretraining_dataset:
+        train_dataset, eval_dataset = load_prepare_datasets(
+            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+        )
+    else:
+        train_dataset = load_pretraining_dataset(
+            cfg.pretraining_dataset,
+            tokenizer,
+            max_tokens=cfg.sequence_len,
+            seed=cfg.seed or 42,
+        )
+        # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
+        train_dataset = train_dataset.with_format("torch")
+        eval_dataset = None
+
+    with zero_first(is_main_process()):
+        train_dataset, eval_dataset = process_datasets_for_packing(
+            cfg, train_dataset, eval_dataset
+        )
+    if cfg.max_steps:
+        total_num_steps = min(
+            calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
+        )
+        LOG.info(f"Maximum number of steps set at {total_num_steps}")
+    else:
+        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+    return train_dataset, eval_dataset, total_num_steps
 
 
 def load_tokenized_prepared_datasets(
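
As a usage note on the new helper: when cfg.max_steps is set, the returned total_num_steps is capped at that value via min(), and when cfg.pretraining_dataset is set, the streaming dataset is put into torch format and eval_dataset is set to None. A small sketch of the pretraining branch as a caller would see it follows; the dataset name and step count are illustrative, tokenizer is assumed to be defined as in the earlier sketch, and the second assertion additionally assumes process_datasets_for_packing passes a None eval set through unchanged:

# Pretraining-style call; the dataset name here is purely illustrative.
cfg = DictDefault(
    {
        "pretraining_dataset": "togethercomputer/RedPajama-Data-1T",
        "sequence_len": 2048,
        "max_steps": 1000,
        "seed": 42,
    }
)

train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)

assert total_num_steps <= cfg.max_steps  # capped by min(..., cfg.max_steps)
assert eval_dataset is None              # the pretraining branch builds no eval split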