jinwonkim93 jinwonkim93@github.com winglian commited on
Commit
553c80f
·
unverified ·
1 Parent(s): eb4c994

streaming multipack for pretraining dataset (#959)

Browse files

* [Feat] streaming multipack

* WIP make continued pretraining work w multipack

* fix up hadrcoding, lint

* fix dict check

* update test for updated pretraining multipack code

* fix hardcoded data collator fix for multipack pretraining

* fix the collator to be the max length for multipack pretraining

* don't bother with latest tag for test

* cleanup docker build/test

---------

Co-authored-by: jinwonkim93@github.com <jinwonkim>
Co-authored-by: Wing Lian <wing.lian@gmail.com>

.github/workflows/tests-docker.yml CHANGED
@@ -20,7 +20,6 @@ jobs:
20
  python_version: "3.10"
21
  pytorch: 2.0.1
22
  axolotl_extras:
23
- is_latest: true
24
  - cuda: 121
25
  cuda_version: 12.1.0
26
  python_version: "3.10"
@@ -37,7 +36,7 @@ jobs:
37
  images: winglian/axolotl
38
  - name: Set up Docker Buildx
39
  uses: docker/setup-buildx-action@v3
40
- - name: Build and export to Docker
41
  uses: docker/build-push-action@v5
42
  with:
43
  context: .
@@ -49,8 +48,7 @@ jobs:
49
  file: ./docker/Dockerfile
50
  tags: |
51
  ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
52
- ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
53
  labels: ${{ steps.metadata.outputs.labels }}
54
- - name: Unit Tests
55
  run: |
56
  docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
 
20
  python_version: "3.10"
21
  pytorch: 2.0.1
22
  axolotl_extras:
 
23
  - cuda: 121
24
  cuda_version: 12.1.0
25
  python_version: "3.10"
 
36
  images: winglian/axolotl
37
  - name: Set up Docker Buildx
38
  uses: docker/setup-buildx-action@v3
39
+ - name: Build Docker image
40
  uses: docker/build-push-action@v5
41
  with:
42
  context: .
 
48
  file: ./docker/Dockerfile
49
  tags: |
50
  ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
 
51
  labels: ${{ steps.metadata.outputs.labels }}
52
+ - name: Unit Tests w docker image
53
  run: |
54
  docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
examples/tiny-llama/pretrain.yml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
2
+
3
+ model_type: LlamaForCausalLM
4
+ tokenizer_type: LlamaTokenizer
5
+ is_llama_derived_model: true
6
+
7
+ load_in_8bit: false
8
+ load_in_4bit: false
9
+ strict: false
10
+
11
+ max_steps: 200
12
+ pretraining_dataset:
13
+ path: c4
14
+ name: en
15
+ dataset_prepared_path:
16
+ val_set_size: 0.0
17
+ output_dir: ./model-out
18
+
19
+ sequence_len: 2048
20
+ sample_packing: true
21
+
22
+ wandb_project:
23
+ wandb_entity:
24
+ wandb_watch:
25
+ wandb_name:
26
+ wandb_log_model:
27
+
28
+ gradient_accumulation_steps: 4
29
+ micro_batch_size: 2
30
+ num_epochs: 4
31
+ optimizer: adamw_bnb_8bit
32
+ lr_scheduler: cosine
33
+ learning_rate: 0.0002
34
+
35
+ train_on_inputs: false
36
+ group_by_length: false
37
+ bf16: true
38
+ fp16: false
39
+ tf32: false
40
+
41
+ gradient_checkpointing: true
42
+ early_stopping_patience:
43
+ resume_from_checkpoint:
44
+ local_rank:
45
+ logging_steps: 1
46
+ xformers_attention:
47
+ flash_attention: true
48
+
49
+ warmup_steps: 10
50
+ evals_per_epoch:
51
+ eval_table_size:
52
+ saves_per_epoch: 1
53
+ debug:
54
+ deepspeed:
55
+ weight_decay: 0.0
56
+ fsdp:
57
+ fsdp_config:
58
+ special_tokens:
src/axolotl/core/trainer_builder.py CHANGED
@@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
60
  default=False,
61
  metadata={"help": "Use quadratic warmup for cosine scheduling."},
62
  )
 
 
 
 
 
 
63
  sample_packing: bool = field(
64
  default=False,
65
  metadata={"help": "Use sample packing for efficient training."},
@@ -157,7 +163,7 @@ class AxolotlTrainer(Trainer):
157
  return self.lr_scheduler
158
 
159
  def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
160
- if self.args.sample_packing:
161
  return MultipackBatchSampler(
162
  RandomSampler(self.train_dataset),
163
  self.args.train_batch_size,
@@ -193,7 +199,7 @@ class AxolotlTrainer(Trainer):
193
  return super()._get_eval_sampler(eval_dataset)
194
 
195
  def get_train_dataloader(self) -> DataLoader:
196
- if self.args.sample_packing:
197
  train_dataset = self.train_dataset
198
  train_dataset = train_dataset.remove_columns(["length"])
199
  data_collator = self.data_collator
@@ -768,6 +774,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
768
  training_arguments_kwargs
769
  )
770
  training_arguments_kwargs["model_type"] = self.cfg.model_config_type
 
771
 
772
  if self.cfg.neftune_noise_alpha is not None:
773
  training_arguments_kwargs[
@@ -808,7 +815,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
808
  train_dataset=self.train_dataset,
809
  eval_dataset=self.eval_dataset,
810
  args=training_args,
811
- data_collator=self.build_collator(**data_collator_kwargs),
812
  bench_data_collator=transformers.DataCollatorForSeq2Seq(
813
  self.tokenizer,
814
  return_tensors="pt",
@@ -829,7 +836,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
829
 
830
  return trainer
831
 
832
- def build_collator(self, **kwargs):
 
 
 
833
  if self.cfg.model_config_type == "mamba":
834
  return MambaDataCollator(tokenizer=self.tokenizer)
835
 
 
60
  default=False,
61
  metadata={"help": "Use quadratic warmup for cosine scheduling."},
62
  )
63
+ pretraining: bool = field(
64
+ default=False,
65
+ metadata={
66
+ "help": "Indicates to trainer whether we are doing continued pretraining."
67
+ },
68
+ )
69
  sample_packing: bool = field(
70
  default=False,
71
  metadata={"help": "Use sample packing for efficient training."},
 
163
  return self.lr_scheduler
164
 
165
  def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
166
+ if self.args.sample_packing and not self.args.pretraining:
167
  return MultipackBatchSampler(
168
  RandomSampler(self.train_dataset),
169
  self.args.train_batch_size,
 
199
  return super()._get_eval_sampler(eval_dataset)
200
 
201
  def get_train_dataloader(self) -> DataLoader:
202
+ if self.args.sample_packing and not self.args.pretraining:
203
  train_dataset = self.train_dataset
204
  train_dataset = train_dataset.remove_columns(["length"])
205
  data_collator = self.data_collator
 
774
  training_arguments_kwargs
775
  )
776
  training_arguments_kwargs["model_type"] = self.cfg.model_config_type
777
+ training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
778
 
779
  if self.cfg.neftune_noise_alpha is not None:
780
  training_arguments_kwargs[
 
815
  train_dataset=self.train_dataset,
816
  eval_dataset=self.eval_dataset,
817
  args=training_args,
818
+ data_collator=self.build_collator(training_args, **data_collator_kwargs),
819
  bench_data_collator=transformers.DataCollatorForSeq2Seq(
820
  self.tokenizer,
821
  return_tensors="pt",
 
836
 
837
  return trainer
838
 
839
+ def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
840
+ if training_args.pretraining:
841
+ return None
842
+
843
  if self.cfg.model_config_type == "mamba":
844
  return MambaDataCollator(tokenizer=self.tokenizer)
845
 
src/axolotl/utils/collators.py CHANGED
@@ -178,3 +178,24 @@ class MambaDataCollator:
178
  "input_ids": input_ids,
179
  "labels": labels,
180
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  "input_ids": input_ids,
179
  "labels": labels,
180
  }
181
+
182
+
183
+ @dataclass
184
+ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
185
+ """
186
+ Collator for multipack specific to the using the BatchSampler
187
+ """
188
+
189
+ def __call__(self, features, return_tensors=None):
190
+ chunked_data = {}
191
+ for feature in features.keys():
192
+ if feature == "length":
193
+ continue
194
+ if feature == "attention_mask":
195
+ arrays = [(1) * np.array(item) for item in features[feature]]
196
+ chunked_data[feature] = np.concatenate(arrays)
197
+ else:
198
+ arrays = [np.array(item) for item in features[feature]]
199
+ chunked_data[feature] = np.concatenate(arrays)
200
+ features = [chunked_data]
201
+ return super().__call__(features, return_tensors=return_tensors)
src/axolotl/utils/data.py CHANGED
@@ -2,6 +2,7 @@
2
  import functools
3
  import hashlib
4
  import logging
 
5
  from pathlib import Path
6
  from typing import Dict, List, Tuple, Union
7
 
@@ -14,6 +15,7 @@ from datasets import (
14
  load_from_disk,
15
  )
16
  from huggingface_hub import hf_hub_download
 
17
  from transformers import PreTrainedTokenizerBase
18
 
19
  from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -39,11 +41,14 @@ from axolotl.prompters import (
39
  SummarizeTLDRPrompter,
40
  UnsupportedPrompter,
41
  )
 
42
  from axolotl.utils.dict import DictDefault
43
  from axolotl.utils.distributed import is_main_process, zero_first
 
44
  from axolotl.utils.trainer import (
45
  calculate_total_num_steps,
46
  process_datasets_for_packing,
 
47
  )
48
 
49
  LOG = logging.getLogger("axolotl")
@@ -64,9 +69,17 @@ def prepare_dataset(cfg, tokenizer):
64
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
65
  )
66
  else:
 
 
 
 
 
 
67
  train_dataset = load_pretraining_dataset(
68
- cfg.pretraining_dataset,
69
  tokenizer,
 
 
70
  max_tokens=cfg.sequence_len,
71
  seed=cfg.seed or 42,
72
  )
@@ -806,9 +819,27 @@ def encode_pretraining(
806
  return ret
807
 
808
 
809
- def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
810
- encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
811
- dataset = load_dataset(path, streaming=True, split="train")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
812
  dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
813
  dataset = dataset.map(
814
  encode,
@@ -819,3 +850,63 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
819
  remove_columns=dataset.features.keys(),
820
  )
821
  return dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import functools
3
  import hashlib
4
  import logging
5
+ from collections import defaultdict
6
  from pathlib import Path
7
  from typing import Dict, List, Tuple, Union
8
 
 
15
  load_from_disk,
16
  )
17
  from huggingface_hub import hf_hub_download
18
+ from torch.utils.data import RandomSampler
19
  from transformers import PreTrainedTokenizerBase
20
 
21
  from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 
41
  SummarizeTLDRPrompter,
42
  UnsupportedPrompter,
43
  )
44
+ from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
45
  from axolotl.utils.dict import DictDefault
46
  from axolotl.utils.distributed import is_main_process, zero_first
47
+ from axolotl.utils.samplers.multipack import MultipackBatchSampler
48
  from axolotl.utils.trainer import (
49
  calculate_total_num_steps,
50
  process_datasets_for_packing,
51
+ process_pretraining_datasets_for_packing,
52
  )
53
 
54
  LOG = logging.getLogger("axolotl")
 
69
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
70
  )
71
  else:
72
+ path = cfg.pretraining_dataset
73
+ name = None
74
+ if isinstance(cfg.pretraining_dataset, dict):
75
+ path = cfg.pretraining_dataset["path"]
76
+ name = cfg.pretraining_dataset["name"]
77
+
78
  train_dataset = load_pretraining_dataset(
79
+ path,
80
  tokenizer,
81
+ cfg,
82
+ name=name,
83
  max_tokens=cfg.sequence_len,
84
  seed=cfg.seed or 42,
85
  )
 
819
  return ret
820
 
821
 
822
+ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
823
+ if cfg.sample_packing:
824
+ collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
825
+ tokenizer,
826
+ return_tensors="pt",
827
+ padding=True,
828
+ pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
829
+ )
830
+ encode = functools.partial(
831
+ encode_packed_pretraining,
832
+ tokenizer,
833
+ collate_fn,
834
+ max_seq_length=max_tokens,
835
+ batch_size=cfg.micro_batch_size,
836
+ )
837
+ # set this to 1 so downstream data_loader doesn't try to increase the batch again
838
+ cfg.micro_batch_size = 1
839
+ else:
840
+ encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
841
+
842
+ dataset = load_dataset(path, streaming=True, split="train", name=name)
843
  dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
844
  dataset = dataset.map(
845
  encode,
 
850
  remove_columns=dataset.features.keys(),
851
  )
852
  return dataset
853
+
854
+
855
+ def encode_packed_pretraining(
856
+ tokenizer: PreTrainedTokenizerBase,
857
+ collate_fn,
858
+ examples: List[str],
859
+ max_seq_length: int = 2048,
860
+ batch_size: int = 4,
861
+ ) -> Dict[str, List]:
862
+ # pylint: disable=duplicate-code
863
+ # tokenize all the examples
864
+ # rows get split with stride (overlap)
865
+ res = tokenizer(
866
+ examples,
867
+ truncation=True,
868
+ max_length=max_seq_length - 1,
869
+ add_special_tokens=True,
870
+ return_overflowing_tokens=True,
871
+ stride=256,
872
+ )
873
+
874
+ input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
875
+ attention_mask = [seq + [1] for seq in res["attention_mask"]]
876
+
877
+ tokenized_examples = {
878
+ "input_ids": input_ids,
879
+ "attention_mask": attention_mask,
880
+ }
881
+
882
+ train_dataset = Dataset.from_dict(tokenized_examples)
883
+ train_dataset = process_pretraining_datasets_for_packing(
884
+ train_dataset, max_seq_length
885
+ )
886
+
887
+ sampler = MultipackBatchSampler(
888
+ RandomSampler(train_dataset),
889
+ batch_size=batch_size,
890
+ drop_last=True,
891
+ batch_max_len=batch_size * max_seq_length,
892
+ lengths=(
893
+ train_dataset.data.column("position_ids")
894
+ .to_pandas()
895
+ .apply(lambda x: x[-1] + 1)
896
+ .values
897
+ ),
898
+ )
899
+
900
+ chunked_data = defaultdict(list)
901
+
902
+ for data in sampler:
903
+ features = train_dataset[data]
904
+ features["labels"] = features["input_ids"].copy()
905
+ collated_features = collate_fn(features)
906
+
907
+ for feature in features.keys():
908
+ if feature == "length":
909
+ continue
910
+ chunked_data[feature].append(collated_features[feature].squeeze(0))
911
+
912
+ return chunked_data
src/axolotl/utils/trainer.py CHANGED
@@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
143
  return train_dataset, eval_dataset
144
 
145
 
 
 
 
 
 
 
 
 
 
 
146
  def calculate_total_num_steps(cfg, train_dataset, update=True):
147
  if not cfg.total_num_tokens:
148
  total_num_tokens = np.sum(
 
143
  return train_dataset, eval_dataset
144
 
145
 
146
+ def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
147
+ drop_long = partial(drop_long_seq, sequence_len=sequence_len)
148
+
149
+ train_dataset = train_dataset.filter(drop_long)
150
+ train_dataset = train_dataset.map(
151
+ add_position_ids,
152
+ )
153
+ return train_dataset
154
+
155
+
156
  def calculate_total_num_steps(cfg, train_dataset, update=True):
157
  if not cfg.total_num_tokens:
158
  total_num_tokens = np.sum(
tests/test_packed_pretraining.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for testing streaming dataset sequence packing"""
2
+ import unittest
3
+ from functools import partial
4
+
5
+ import torch
6
+ from datasets import load_dataset
7
+ from torch.utils.data import DataLoader
8
+ from transformers import AutoTokenizer
9
+
10
+ from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
11
+ from axolotl.utils.data import encode_packed_pretraining
12
+
13
+
14
+ class TestPacking(unittest.TestCase):
15
+ """
16
+ Test class for packing streaming dataset sequences
17
+ """
18
+
19
+ def setUp(self) -> None:
20
+ # pylint: disable=duplicate-code
21
+ self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
22
+ self.tokenizer.pad_token = "</s>"
23
+ self.max_seq_length = 2048
24
+ self.batch_size = 2
25
+
26
+ def test_packing_stream_dataset(self):
27
+ # pylint: disable=duplicate-code
28
+ dataset = load_dataset(
29
+ "c4",
30
+ "en",
31
+ streaming=True,
32
+ )["train"]
33
+
34
+ collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
35
+ self.tokenizer,
36
+ return_tensors="pt",
37
+ padding=True,
38
+ pad_to_multiple_of=self.max_seq_length,
39
+ )
40
+
41
+ encode = partial(
42
+ encode_packed_pretraining,
43
+ self.tokenizer,
44
+ collate_fn,
45
+ max_seq_length=self.max_seq_length,
46
+ batch_size=self.batch_size,
47
+ )
48
+
49
+ dataset = dataset.map(
50
+ encode,
51
+ batched=True,
52
+ input_columns="text",
53
+ remove_columns=dataset.features.keys(),
54
+ )
55
+
56
+ trainer_loader = DataLoader(
57
+ dataset,
58
+ batch_size=1,
59
+ collate_fn=None,
60
+ drop_last=True,
61
+ )
62
+ idx = 0
63
+ for data in trainer_loader:
64
+ if idx > 10:
65
+ break
66
+ assert data["input_ids"].shape == torch.Size(
67
+ [1, self.batch_size * self.max_seq_length]
68
+ )
69
+ assert data["position_ids"].shape == torch.Size(
70
+ [1, self.batch_size * self.max_seq_length]
71
+ )
72
+ assert data["labels"].shape == torch.Size(
73
+ [1, self.batch_size * self.max_seq_length]
74
+ )
75
+ assert data["attention_mask"].shape == torch.Size(
76
+ [1, self.batch_size * self.max_seq_length]
77
+ )
78
+ idx += 1
79
+
80
+
81
+ if __name__ == "__main__":
82
+ unittest.main()