optimize dataloading to use cache, fix model token embedding sizes
- src/axolotl/utils/data.py  +72 -14
- src/axolotl/utils/models.py  +2 -0
src/axolotl/utils/data.py

@@ -31,13 +31,7 @@ from axolotl.prompters import (
 )
 
 
-def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
-    max_packed_sequence_len = (
-        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
-    )
-    max_packed_sequence_len = min(
-        max_packed_sequence_len, cfg.sequence_len
-    )  # make sure we don't accidentally set it larger than sequence_len
+def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
     ds_hash = str(
         md5(
             (
@@ -54,7 +48,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
 
     if any(prepared_ds_path.glob("*")):
-        logging.info(f"Loading prepared dataset from disk...")
+        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
@@ -153,14 +147,78 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
     dataset.save_to_disk(prepared_ds_path)
 
+    return dataset
+
+
+def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+
     if cfg.max_packed_sequence_len is not None:
-        constant_len_dataset = ConstantLengthDataset(
-            tokenizer,
-            [dataset],
-            seq_length=max_packed_sequence_len,
+        # see if we can go ahead and load the stacked dataset
+
+        ds_hash = str(
+            md5(
+                (
+                    str(cfg.sequence_len)
+                    + "@"
+                    + str(max_packed_sequence_len)
+                    + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+                ).encode("utf-8")
+            ).hexdigest()
+        )
+        prepared_ds_path = (
+            Path(cfg.dataset_prepared_path) / ds_hash
+            if cfg.dataset_prepared_path
+            else Path(default_dataset_prepared_path) / ds_hash
+        )
+
+        if any(prepared_ds_path.glob("*")):
+            logging.info(
+                f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
+            )
+            dataset = load_from_disk(str(prepared_ds_path))
+            logging.info("Prepared packed dataset loaded from disk...")
+        else:
+            dataset = load_tokenized_prepared_datasets(
+                tokenizer, cfg, default_dataset_prepared_path
+            )
+
+            constant_len_dataset = ConstantLengthDataset(
+                tokenizer,
+                [dataset],
+                seq_length=max_packed_sequence_len,
+            )
+            logging.info(
+                f"packing master dataset to len: {cfg.max_packed_sequence_len}"
+            )
+            dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+            if cfg.local_rank == 0:
+                logging.info(
+                    f"Saving packed prepared dataset to disk... {prepared_ds_path}"
+                )
+                dataset.save_to_disk(prepared_ds_path)
+    else:
+        dataset = load_tokenized_prepared_datasets(
+            tokenizer, cfg, default_dataset_prepared_path
         )
-        logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+        # filter out bad data
+        dataset = Dataset.from_list(
+            [
+                d
+                for d in dataset
+                if len(d["input_ids"]) > cfg.sequence_len
+                and len(d["input_ids"]) > 0
+                and len(d["input_ids"]) == len(d["attention_mask"])
+                and len(d["input_ids"]) == len(d["labels"])
+            ]
+        )
 
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         logging.info(
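Taken together, the data.py changes split preparation into two cacheable stages: load_tokenized_prepared_datasets caches the tokenized data keyed by an md5 of the dataset config, and load_prepare_datasets adds a second cache for the packed variant keyed by sequence_len and max_packed_sequence_len. The sketch below shows the underlying hash-keyed disk-cache pattern in isolation; load_or_build, build_dataset, and the default path are illustrative placeholders rather than axolotl functions, and the cfg fields are assumed to match the ones used in the diff.

from hashlib import md5
from pathlib import Path

from datasets import load_from_disk


def load_or_build(cfg, build_dataset, default_prepared_path="prepared_cache"):
    # Hash every input that affects the prepared data, so a config change
    # maps to a different cache directory instead of reusing stale output.
    ds_hash = md5(
        (
            str(cfg.sequence_len)
            + "@"
            + str(cfg.max_packed_sequence_len)
            + "|".join(sorted(f"{d.path}:{d.type}" for d in cfg.datasets))
        ).encode("utf-8")
    ).hexdigest()
    prepared_ds_path = Path(cfg.dataset_prepared_path or default_prepared_path) / ds_hash

    if any(prepared_ds_path.glob("*")):
        # Cache hit: reuse the previously tokenized/packed dataset.
        return load_from_disk(str(prepared_ds_path))

    dataset = build_dataset(cfg)  # the expensive tokenize/pack step
    if cfg.local_rank == 0:
        # Only rank 0 writes, so distributed workers don't race on the cache dir.
        dataset.save_to_disk(str(prepared_ds_path))
    return dataset

Because the packed cache key includes max_packed_sequence_len, changing the packing length triggers a rebuild rather than silently reusing stale packed data.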
src/axolotl/utils/models.py

@@ -181,6 +181,8 @@ def load_model(
     for k, v in cfg.tokens.items():
         tokenizer.add_special_tokens({k: v})
 
+    model.resize_token_embeddings(len(tokenizer))
+
     if cfg.adapter and load_in_8bit and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
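The models.py change covers the "fix model token embedding sizes" half of the commit: once cfg.tokens adds special tokens, len(tokenizer) can exceed the vocabulary size the checkpoint was saved with, and any new token id would index past the end of the embedding matrix. Resizing right after the add_special_tokens loop keeps the input embeddings (and any tied output head) in sync. A generic transformers illustration, using gpt2 purely as a stand-in model:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# gpt2 ships without a pad token; adding one grows the tokenizer vocab by 1.
num_added = tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if num_added > 0:
    # Without this, the new token id (>= the checkpoint's vocab size) would
    # index out of range in the embedding and output projection weights.
    model.resize_token_embeddings(len(tokenizer))

In the diff the resize is done unconditionally, which is effectively a no-op when no tokens were added and the vocabulary size is unchanged.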
|