winglian committed
Commit aa3c3f9
1 Parent(s): f6d1fa4

optimize dataloading to use cache, fix model token embedding sizes

src/axolotl/utils/data.py CHANGED
@@ -31,13 +31,7 @@ from axolotl.prompters import (
 )
 
 
-def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
-    max_packed_sequence_len = (
-        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
-    )
-    max_packed_sequence_len = min(
-        max_packed_sequence_len, cfg.sequence_len
-    )  # make sure we don't accidentally set it larger than sequence_len
+def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
     ds_hash = str(
         md5(
             (
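Note on the cache keying used here: the prepared dataset is fingerprinted with an md5 over the settings that affect tokenization, so any config change produces a new cache directory while an identical rerun reuses the existing one. A minimal standalone sketch of the idea (the helper name and arguments are illustrative, not part of the commit):

    from hashlib import md5
    from pathlib import Path

    def make_cache_path(base_dir, sequence_len, datasets):
        # Illustrative mirror of the commit's keying: hash everything that
        # changes the prepared data so each distinct config gets its own dir.
        key = str(sequence_len) + "|".join(
            sorted(f"{path}:{dtype}" for path, dtype in datasets)
        )
        return Path(base_dir) / md5(key.encode("utf-8")).hexdigest()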
@@ -54,7 +48,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
 
     if any(prepared_ds_path.glob("*")):
-        logging.info(f"Loading prepared dataset from disk ay {prepared_ds_path}...")
+        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
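The cache check above treats any file under the prepared path as a hit. Roughly, the round trip with the datasets library looks like this (hypothetical wrapper; the commit inlines this logic rather than factoring it out):

    from datasets import load_from_disk

    def load_or_build(prepared_ds_path, build_fn):
        # Any file under the prepared path means a previous run saved it.
        if any(prepared_ds_path.glob("*")):
            return load_from_disk(str(prepared_ds_path))
        dataset = build_fn()
        dataset.save_to_disk(str(prepared_ds_path))
        return dataset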
@@ -153,14 +147,78 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
     dataset.save_to_disk(prepared_ds_path)
 
+    return dataset
+
+
+def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+
     if cfg.max_packed_sequence_len is not None:
-        constant_len_dataset = ConstantLengthDataset(
-            tokenizer,
-            [dataset],
-            seq_length=max_packed_sequence_len,
+        # see if we can go ahead and load the stacked dataset
+
+        ds_hash = str(
+            md5(
+                (
+                    str(cfg.sequence_len)
+                    + "@"
+                    + str(max_packed_sequence_len)
+                    + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+                ).encode("utf-8")
+            ).hexdigest()
+        )
+        prepared_ds_path = (
+            Path(cfg.dataset_prepared_path) / ds_hash
+            if cfg.dataset_prepared_path
+            else Path(default_dataset_prepared_path) / ds_hash
+        )
+
+        if any(prepared_ds_path.glob("*")):
+            logging.info(
+                f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
+            )
+            dataset = load_from_disk(str(prepared_ds_path))
+            logging.info("Prepared packed dataset loaded from disk...")
+        else:
+            dataset = load_tokenized_prepared_datasets(
+                tokenizer, cfg, default_dataset_prepared_path
+            )
+
+            constant_len_dataset = ConstantLengthDataset(
+                tokenizer,
+                [dataset],
+                seq_length=max_packed_sequence_len,
+            )
+            logging.info(
+                f"packing master dataset to len: {cfg.max_packed_sequence_len}"
+            )
+            dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+            if cfg.local_rank == 0:
+                logging.info(
+                    f"Saving packed prepared dataset to disk... {prepared_ds_path}"
+                )
+                dataset.save_to_disk(prepared_ds_path)
+    else:
+        dataset = load_tokenized_prepared_datasets(
+            tokenizer, cfg, default_dataset_prepared_path
         )
-        logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+    # filter out bad data
+    dataset = Dataset.from_list(
+        [
+            d
+            for d in dataset
+            if len(d["input_ids"]) > cfg.sequence_len
+            and len(d["input_ids"]) > 0
+            and len(d["input_ids"]) == len(d["attention_mask"])
+            and len(d["input_ids"]) == len(d["labels"])
+        ]
+    )
 
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         logging.info(
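On the packing step: ConstantLengthDataset streams the tokenized examples into fixed-length sequences, and Dataset.from_list materializes that stream so it can be cached to disk like any other prepared dataset. As a rough illustration of what packing buys (a simplified greedy concatenation, not axolotl's actual implementation):

    def pack_examples(tokenized_examples, seq_length, pad_token_id):
        # Greedy sketch: concatenate examples into buffers of at most
        # seq_length tokens so short samples share a context window
        # instead of each one being padded out to the full length.
        packed, buffer = [], []
        for ex in tokenized_examples:
            ids = ex["input_ids"][:seq_length]
            if buffer and len(buffer) + len(ids) > seq_length:
                packed.append(buffer + [pad_token_id] * (seq_length - len(buffer)))
                buffer = []
            buffer = buffer + ids
        if buffer:
            packed.append(buffer + [pad_token_id] * (seq_length - len(buffer)))
        return packed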
 
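One observation on the new "filter out bad data" block: the consistency checks (non-empty input_ids with lengths matching attention_mask and labels) are straightforward, but the first predicate keeps only rows where len(d["input_ids"]) > cfg.sequence_len; if the intent is to drop rows that exceed the context window, the expected comparison would be <=. Restated as a standalone predicate (hypothetical helper, with the direction corrected under that assumption):

    def is_well_formed(row, sequence_len):
        # Mirrors the commit's predicates; note the commit compares with
        # `>` where `<=` appears to be the intended direction.
        n = len(row["input_ids"])
        return (
            0 < n <= sequence_len
            and n == len(row["attention_mask"])
            and n == len(row["labels"])
        )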
src/axolotl/utils/models.py CHANGED
@@ -181,6 +181,8 @@ def load_model(
     for k, v in cfg.tokens.items():
         tokenizer.add_special_tokens({k: v})
 
+    model.resize_token_embeddings(len(tokenizer))
+
     if cfg.adapter and load_in_8bit and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
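For background on the models.py change: add_special_tokens grows the tokenizer's vocabulary, and resize_token_embeddings(len(tokenizer)) resizes the model's embedding matrix (and any tied output head) to match, so the new token ids don't index past the end of the old matrix. A minimal transformers example (the checkpoint name is illustrative):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # Without the resize, the new token id equals the old vocab size and
    # would be out of range for the embedding lookup.
    model.resize_token_embeddings(len(tokenizer))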