address PR feedback
Files changed:

- examples/pythia-12b/README.md  +1 -1
- examples/pythia-12b/config.yml  +2 -2
- scripts/finetune.py  +4 -1
- src/axolotl/utils/data.py  +2 -2
- src/axolotl/utils/trainer.py  +0 -2
examples/pythia-12b/README.md

@@ -1,4 +1,4 @@
-#
+# Pythia 12B
 
 - Single-GPU A100 only (?)
 
examples/pythia-12b/config.yml

@@ -22,7 +22,7 @@ lora_dropout: 0.0
 lora_target_modules:
 lora_target_linear: true
 lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
-wandb_project:
+wandb_project:
 wandb_watch:
 wandb_run_id:
 wandb_log_model:
@@ -45,5 +45,5 @@ resume_from_checkpoint:
 local_rank:
 gradient_checkpointing: true
 fsdp:
-
+fsdp_config:
 collator_pad_to_longest: true
scripts/finetune.py

@@ -208,7 +208,10 @@ def train(
         )
     else:
         train_dataset = load_pretraining_dataset(
-            cfg.pretraining_dataset,
+            cfg.pretraining_dataset,
+            tokenizer,
+            max_tokens=cfg.sequence_len,
+            seed=cfg.seed,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
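For context, the call site now forwards the tokenizer, the configured sequence length, and the shuffle seed into the pretraining loader instead of relying on the loader's defaults. A minimal sketch of that wiring, assuming a simple namespace-style `cfg` and a hypothetical corpus path; only the field names `pretraining_dataset`, `sequence_len`, and `seed` come from the diff, everything else is illustrative:

```python
from types import SimpleNamespace

from transformers import AutoTokenizer

from axolotl.utils.data import load_pretraining_dataset

# Illustrative stand-in for the parsed YAML config; field names match the diff.
cfg = SimpleNamespace(
    pretraining_dataset="togethercomputer/RedPajama-Data-1T-Sample",  # hypothetical corpus
    sequence_len=2048,
    seed=42,
)

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-12b-deduped")

# Mirrors the updated call in scripts/finetune.py.
train_dataset = load_pretraining_dataset(
    cfg.pretraining_dataset,
    tokenizer,
    max_tokens=cfg.sequence_len,
    seed=cfg.seed,
)
# Streaming datasets still need the torch format before being handed to the HF Trainer.
train_dataset = train_dataset.with_format("torch")
```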
src/axolotl/utils/data.py

@@ -505,10 +505,10 @@ def encode_pretraining(tokenizer, max_tokens, examples):
     return ret
 
 
-def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
     dataset = load_dataset(path, streaming=True, split="train")
-    dataset = dataset.shuffle(seed=
+    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
     # TODO dynamically figure out which columns/features to remove
     dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
     return dataset
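To make the shuffling change concrete, here is a self-contained sketch of the pattern `load_pretraining_dataset` follows after this commit: stream the split, shuffle it within a bounded buffer, then tokenize lazily through a partially applied encode function. The helper names and the simplified tokenization below are assumptions for illustration; `encode_pretraining` itself is defined earlier in data.py and is not part of this hunk.

```python
import functools

from datasets import load_dataset
from transformers import AutoTokenizer


def encode_batch(tokenizer, max_tokens, examples):
    # Simplified stand-in for encode_pretraining: tokenize the raw "text"
    # column into fixed-length sequences.
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_tokens,
        padding="max_length",
    )


def load_streaming_corpus(path, tokenizer, max_tokens=2048, seed=42):
    # Same shape as load_pretraining_dataset after this change.
    encode = functools.partial(encode_batch, tokenizer, max_tokens)
    dataset = load_dataset(path, streaming=True, split="train")
    # A streamed IterableDataset cannot be globally shuffled; shuffle() keeps
    # a reservoir of `buffer_size` examples and samples from it, so exposing
    # the seed is what makes the ordering reproducible across runs.
    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
    dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
    return dataset
```

Passing `seed` through as a parameter (with the previous value kept as the default) lets callers such as scripts/finetune.py drive the shuffle order from the run config rather than a hard-coded constant.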
src/axolotl/utils/trainer.py

@@ -1,7 +1,6 @@
 """Module containing the Trainer class and related functions"""
 
 import importlib
-import logging
 import math
 import os
 import sys
@@ -232,7 +231,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     callbacks.append(SavePeftModelCallback)
 
     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
-        logging.info("Setting up SaveBetterTransformerModelCallback.")
         callbacks.append(SaveBetterTransformerModelCallback)
 
     data_collator_kwargs = {
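The trainer change only drops the now-unused logging import and its info message; the conditional callback registration itself is unchanged. For reference, a minimal sketch of that guard pattern with a hypothetical stand-in callback (`SaveBetterTransformerModelCallback` is defined elsewhere in axolotl and is not shown in this diff):

```python
from transformers import TrainerCallback


class SaveOnTrainEnd(TrainerCallback):
    """Hypothetical stand-in for SaveBetterTransformerModelCallback."""

    def on_train_end(self, args, state, control, **kwargs):
        # A real callback would convert the model back from BetterTransformer
        # and save it here; this stub just returns control unchanged.
        return control


def collect_callbacks(model):
    # Same guard as setup_trainer: register the callback only when the model
    # was actually wrapped with BetterTransformer.
    callbacks = []
    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
        callbacks.append(SaveOnTrainEnd)
    return callbacks
```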