winglian committed on
Commit cda52dc · unverified · 1 Parent(s): e799e08

support for explicit test_dataset definition for evals (#786)

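For orientation, here is a minimal config sketch (not part of this commit; the dataset paths and the "alpaca" type are placeholders, and only field names that appear in the diffs below are used) contrasting the two ways of sourcing an eval set after this change:

# Option 1: carve the eval set out of the training data (pre-existing behaviour).
cfg_holdout = {
    "datasets": [{"path": "my_org/train_data", "type": "alpaca"}],
    "val_set_size": 0.05,
}

# Option 2 (new): point evals at an explicit dataset; val_set_size must stay unset/zero,
# since validate_config() now rejects the combination.
cfg_explicit_evals = {
    "datasets": [{"path": "my_org/train_data", "type": "alpaca"}],
    "test_datasets": [{"path": "my_org/eval_data", "type": "alpaca", "split": "test"}],
    "val_set_size": 0,
}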
src/axolotl/utils/config.py CHANGED
@@ -519,6 +519,11 @@ def validate_config(cfg):
             "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
         )
 
+    if cfg.test_datasets and cfg.val_set_size:
+        raise ValueError(
+            "non-zero val_set_size should not be used with test_datasets configuration"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
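A minimal sketch of the guard's behaviour (illustrative only; the Cfg class here is a stand-in for axolotl's dict-with-attribute-access config object):

class Cfg(dict):
    __getattr__ = dict.get  # missing keys read as None, so unset options are falsy

def check_eval_config(cfg):
    # mirrors the added check: an explicit test_datasets list and a non-zero val_set_size are mutually exclusive
    if cfg.test_datasets and cfg.val_set_size:
        raise ValueError(
            "non-zero val_set_size should not be used with test_datasets configuration"
        )

check_eval_config(Cfg(test_datasets=[{"path": "my_org/eval_data"}], val_set_size=0))  # passes
# check_eval_config(Cfg(test_datasets=[{"path": "my_org/eval_data"}], val_set_size=0.05))  # raises ValueError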
src/axolotl/utils/data.py CHANGED
@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from datasets import (
@@ -65,9 +65,17 @@ def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
-            train_dataset, eval_dataset, prompters = load_prepare_datasets(
-                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-            )
+            if cfg.test_datasets:
+                train_dataset, _, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
+                )
+                _, eval_dataset, _ = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
+                )
+            else:
+                train_dataset, eval_dataset, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+                )
     else:
         path = cfg.pretraining_dataset
         name = None
@@ -108,8 +116,12 @@
 
 
 def load_tokenized_prepared_datasets(
-    tokenizer, cfg, default_dataset_prepared_path
+    tokenizer,
+    cfg,
+    default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[DatasetDict, List[Prompter]]:
+    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -126,7 +138,7 @@
             sorted(
                 [
                     f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
-                    for d in cfg.datasets
+                    for d in cfg_datasets
                 ]
             )
         )
@@ -149,7 +161,7 @@
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 token=use_auth_token,
             )
-            dataset = dataset["train"]
+            dataset = dataset[split]
         except Exception:  # pylint: disable=broad-except # nosec
             pass
 
@@ -188,8 +200,8 @@
                 yield dataset
 
     # pylint: disable=invalid-name
-    for config_dataset in for_d_in_datasets(cfg.datasets):
-        ds: Union[Dataset, DatasetDict] = None
+    for config_dataset in for_d_in_datasets(cfg_datasets):
+        ds: Optional[Union[Dataset, DatasetDict]] = None
         ds_from_hub = False
         try:
             load_dataset(
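One consequence worth noting (a sketch only, assuming the surrounding hashing code is otherwise unchanged): because ds_hash is now built from whichever list cfg_datasets points at, the train preparation (from datasets) and the test preparation (from test_datasets) produce different cache keys, and a prepared dataset pushed to the hub is read back via the matching split:

# Hypothetical illustration of the cache-key input for each pass; real configs use attribute access.
def hash_input(cfg, split="train"):
    cfg_datasets = cfg["test_datasets"] if split == "test" else cfg["datasets"]
    return "|".join(
        sorted(
            f"{d['path']}:{d.get('type')}:{d.get('shards')}:{d.get('conversation')}"
            for d in cfg_datasets
        )
    )

cfg = {
    "datasets": [{"path": "my_org/train_data", "type": "alpaca"}],
    "test_datasets": [{"path": "my_org/eval_data", "type": "alpaca"}],
}
print(hash_input(cfg, "train"))  # my_org/train_data:alpaca:None:None
print(hash_input(cfg, "test"))   # my_org/eval_data:alpaca:None:None -> a distinct prepared-dataset entry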
@@ -342,16 +354,6 @@
             )
         if not ds:
             raise ValueError("unhandled dataset load")
-        # support for using a subset of the data
-        if config_dataset.shards:
-            if "train" in ds:
-                ds = ds.shuffle(seed=seed)["train"].shard(
-                    num_shards=config_dataset.shards, index=0
-                )
-            else:
-                ds = ds.shuffle(seed=seed).shard(
-                    num_shards=config_dataset.shards, index=0
-                )
 
         d_base_type = d_prompt_style = None
         d_type = config_dataset.type
@@ -359,17 +361,21 @@
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
             d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-        if "train" in ds:
-            ds = ds["train"]
-        elif (
-            isinstance(ds, DatasetDict)
-            and config_dataset.train_on_split
-            and config_dataset.train_on_split in ds
-        ):
-            ds = ds[config_dataset.train_on_split]
+
+        if config_dataset.split and config_dataset.split in ds:
+            ds = ds[config_dataset.split]
+        elif split in ds:
+            ds = ds[split]
         elif isinstance(ds, DatasetDict):
             raise ValueError(
-                f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
+                f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
+            )
+
+        # support for using a subset of the data
+        if config_dataset.shards:
+            shards_idx = config_dataset.get("shards_idx", 0)
+            ds = ds.shuffle(seed=seed).shard(
+                num_shards=config_dataset.shards, index=shards_idx
             )
 
         dataset_wrapper, dataset_prompter = get_dataset_wrapper(
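The rewritten block above resolves which split of a loaded dataset to use and then optionally sub-samples it. A small self-contained sketch of that logic (using the Hugging Face datasets library directly; variable names are illustrative, not repo code):

from datasets import Dataset, DatasetDict

raw = DatasetDict(
    {
        "train": Dataset.from_dict({"text": [f"train {i}" for i in range(100)]}),
        "test": Dataset.from_dict({"text": [f"test {i}" for i in range(20)]}),
    }
)

requested_split = "test"  # the split load_tokenized_prepared_datasets was asked to prepare
per_dataset_split = None  # an explicit per-dataset `split:` override from the config, if any

if per_dataset_split and per_dataset_split in raw:
    ds = raw[per_dataset_split]  # an explicit override wins
elif requested_split in raw:
    ds = raw[requested_split]  # otherwise fall back to the split being prepared
else:
    raise ValueError("no matching split found")

# optional sub-sampling: keep shard `shards_idx` out of `shards` equal pieces of the shuffled split
shards, shards_idx = 4, 0
ds = ds.shuffle(seed=42).shard(num_shards=shards, index=shards_idx)
print(len(ds))  # 5 -- one quarter of the 20-row test split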
@@ -428,6 +434,7 @@ def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[Dataset, Dataset, List[Prompter]]:
     dataset, prompters = load_tokenized_prepared_datasets(
         tokenizer, cfg, default_dataset_prepared_path
@@ -442,7 +449,7 @@
             index=cfg.dataset_shard_idx,
         )
 
-    if cfg.val_set_size:
+    if split == "train" and cfg.val_set_size:
         # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
         to_hash_train = (
             dataset._fingerprint  # pylint: disable=protected-access
@@ -475,6 +482,9 @@
 
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
+    elif split == "test":
+        train_dataset = None
+        eval_dataset = dataset
     else:
         train_dataset = dataset
         eval_dataset = None
 
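Taken together, load_prepare_datasets now has three return shapes depending on the requested split and val_set_size. A compact illustrative summary (not repo code):

def describe_return(split, val_set_size):
    """Illustrative only: which (train_dataset, eval_dataset) pairing the function now produces."""
    if split == "train" and val_set_size:
        return ("train portion of the prepared data", "held-out val_set_size portion")
    if split == "test":
        return (None, "the whole prepared test_datasets data")
    return ("the whole prepared data", None)

print(describe_return("train", 0.05))  # ('train portion of the prepared data', 'held-out val_set_size portion')
print(describe_return("test", 0))      # (None, 'the whole prepared test_datasets data')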