casperhansen committed
Commit e50ab07
Parent(s): 05bd6f1

Create preprocess CLI (#785)

* Create preprocess CLI
* Print prompt template if debugging
* Add print for unsupported prompters
* Formatting
* Formatting
* Refactor variables
* Formatting
* Formatting
* Formatting
* Formatting
- README.md +31 -23
- scripts/finetune.py +0 -2
- src/axolotl/cli/__init__.py +7 -1
- src/axolotl/cli/preprocess.py +53 -0
- src/axolotl/cli/train.py +0 -13
- src/axolotl/common/cli.py +12 -1
- src/axolotl/prompt_tokenizers.py +1 -0
- src/axolotl/prompters.py +69 -14
- src/axolotl/utils/data.py +181 -136
README.md CHANGED

@@ -32,7 +32,6 @@ Features:
 - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
 - [Config](#config)
 - [Train](#train)
-  - [Training w/ Deepspeed](#training-with-deepspeed)
 - [Inference](#inference)
 - [Merge LORA to Base](#merge-lora-to-base)
 - [Common Errors](#common-errors-)
@@ -824,14 +823,41 @@ Run
 accelerate launch -m axolotl.cli.train your_config.yml
 ```
 
-####
+#### Preprocess dataset
+
+You can optionally pre-tokenize dataset with the following before finetuning.
+This is recommended for large datasets.
+
+- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
+- Use `--debug` to see preprocessed examples.
 
-You can optionally pre-tokenize dataset with the following before finetuning:
 ```bash
-
+python -m axolotl.cli.preprocess your_config.yml
 ```
 
-
+#### Multi-GPU
+
+Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
+is the recommended multi-GPU option currently because FSDP may experience
+[loss instability](https://github.com/huggingface/transformers/issues/26498).
+
+##### DeepSpeed
+
+Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
+might typically be able to fit into your GPU's VRAM. More information about the various optimization types
+for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
+
+We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
+
+```yaml
+deepspeed: deepspeed/zero1.json
+```
+
+```shell
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
+```
+
+##### FSDP
 
 - llama FSDP
 ```yaml
@@ -856,24 +882,6 @@ wandb_run_id:
 wandb_log_model:
 ```
 
-### Training with Deepspeed
-
-Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
-might typically be able to fit into your GPU's VRAM. More information about the various optimization types
-for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
-
-We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
-
-```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
-```
-
-or
-
-```yaml
-deepspeed: deepspeed/zero1.json
-```
-
 ### Inference
 
 Pass the appropriate flag to the train command:
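The README stops at the command itself. As an illustration of what preprocessing leaves behind, here is a short sketch of inspecting the prepared output with the `datasets` library; it assumes the default `dataset_prepared_path` of `last_run_prepared` and the one-subfolder-per-dataset-hash layout used by this repo, neither of which is spelled out in this diff, so verify both against your own config.

```python
# Sketch: inspect the output of `python -m axolotl.cli.preprocess your_config.yml`.
# Assumes dataset_prepared_path defaulted to "last_run_prepared" and that each
# prepared dataset was saved to a hash-named subfolder (repo convention, not
# part of this diff).
from pathlib import Path

from datasets import load_from_disk

prepared_dirs = sorted(Path("last_run_prepared").iterdir())
ds = load_from_disk(str(prepared_dirs[0]))
print(ds.features)           # expect input_ids, attention_mask, labels
print(len(ds), "examples")
```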
scripts/finetune.py CHANGED

@@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-        if parsed_cli_args.prepare_ds_only:
-            return
         train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
 
 
src/axolotl/cli/__init__.py CHANGED

@@ -222,7 +222,9 @@ def load_datasets(
 ) -> TrainDatasetMeta:
     tokenizer = load_tokenizer(cfg)
 
-    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
+    train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
+        cfg, tokenizer
+    )
 
     if cli_args.debug or cfg.debug:
         LOG.info("check_dataset_labels...")
@@ -238,6 +240,10 @@ def load_datasets(
             text_only=cli_args.debug_text_only,
         )
 
+        LOG.info("printing prompters...")
+        for prompter in prompters:
+            LOG.info(prompter)
+
     return TrainDatasetMeta(
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
src/axolotl/cli/preprocess.py ADDED

@@ -0,0 +1,53 @@
+"""
+CLI to run training on a model
+"""
+import logging
+from pathlib import Path
+
+import fire
+import transformers
+from colorama import Fore
+
+from axolotl.cli import (
+    check_accelerate_default_config,
+    check_user_token,
+    load_cfg,
+    load_datasets,
+    print_axolotl_text_art,
+)
+from axolotl.common.cli import PreprocessCliArgs
+from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
+
+LOG = logging.getLogger("axolotl.cli.preprocess")
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    check_accelerate_default_config()
+    check_user_token()
+    parser = transformers.HfArgumentParser((PreprocessCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    if not parsed_cfg.dataset_prepared_path:
+        msg = (
+            Fore.RED
+            + "preprocess CLI called without dataset_prepared_path set, "
+            + f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
+            + Fore.RESET
+        )
+        LOG.warning(msg)
+        parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
+
+    _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    LOG.info(
+        Fore.GREEN
+        + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+        + Fore.RESET
+    )
+
+
+if __name__ == "__main__":
+    fire.Fire(do_cli)
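Since `do_cli` is a plain function under `fire.Fire`, it can also be driven programmatically. A minimal sketch, assuming an installed axolotl and an existing config file; the `dataset_prepared_path` override relies on `load_cfg` merging keyword arguments into the config, which holds for this codebase but is worth checking:

```python
# Sketch: invoke the new preprocess entrypoint without the shell.
# Equivalent to: python -m axolotl.cli.preprocess your_config.yml
# "your_config.yml" is a placeholder path, not part of the commit.
from axolotl.cli.preprocess import do_cli

do_cli(config="your_config.yml", dataset_prepared_path="last_run_prepared")
```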
src/axolotl/cli/train.py CHANGED

@@ -6,7 +6,6 @@ from pathlib import Path
 
 import fire
 import transformers
-from colorama import Fore
 
 from axolotl.cli import (
     check_accelerate_default_config,
@@ -16,7 +15,6 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
-from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.train import train
 
 LOG = logging.getLogger("axolotl.cli.train")
@@ -32,18 +30,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
-    if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
-        msg = (
-            Fore.RED
-            + "--prepare_ds_only called without dataset_prepared_path set."
-            + Fore.RESET
-        )
-        LOG.warning(msg)
-        parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
-
     dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    if parsed_cli_args.prepare_ds_only:
-        return
     train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
 
 
src/axolotl/common/cli.py CHANGED

@@ -25,11 +25,22 @@ class TrainerCliArgs:
     debug_num_examples: int = field(default=5)
     inference: bool = field(default=False)
     merge_lora: bool = field(default=False)
-    prepare_ds_only: bool = field(default=False)
     prompter: Optional[str] = field(default=None)
     shard: bool = field(default=False)
 
 
+@dataclass
+class PreprocessCliArgs:
+    """
+    dataclass representing arguments for preprocessing only
+    """
+
+    debug: bool = field(default=False)
+    debug_text_only: bool = field(default=False)
+    debug_num_examples: int = field(default=1)
+    prompter: Optional[str] = field(default=None)
+
+
 def load_model_and_tokenizer(
     *,
     cfg: DictDefault,
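The new `PreprocessCliArgs` dataclass gets its command-line surface for free from `transformers.HfArgumentParser`, which turns each field into a flag. A standalone sketch of that mechanism; the argv values below are made up for illustration:

```python
# Sketch of how HfArgumentParser maps PreprocessCliArgs fields to CLI flags.
from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class PreprocessCliArgs:
    """dataclass representing arguments for preprocessing only"""

    debug: bool = field(default=False)
    debug_text_only: bool = field(default=False)
    debug_num_examples: int = field(default=1)
    prompter: Optional[str] = field(default=None)


parser = transformers.HfArgumentParser((PreprocessCliArgs,))
# Unrecognized strings (e.g. the config path) are handed back untouched.
args, remaining = parser.parse_args_into_dataclasses(
    args=["your_config.yml", "--debug", "--debug_num_examples", "3"],
    return_remaining_strings=True,
)
print(args.debug, args.debug_num_examples)  # True 3
print(remaining)                            # ['your_config.yml']
```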
src/axolotl/prompt_tokenizers.py CHANGED

@@ -245,6 +245,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         raise NotImplementedError
 
     def tokenize_prompt(self, prompt):
+        # pylint: disable=duplicate-code
         (
             instruction,
             input,  # pylint: disable=redefined-builtin
src/axolotl/prompters.py CHANGED

@@ -4,10 +4,12 @@ import logging
 from enum import Enum
 from typing import Generator, Optional, Union
 
+from colorama import Fore
 from fastchat.conversation import Conversation, get_conv_template
 
 LOG = logging.getLogger("axolotl")
 IGNORE_TOKEN_ID = -100
+REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
 
 
 class PromptStyle(Enum):
@@ -55,20 +57,15 @@ class AlpacaPrompter:
         )
         self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
 
-    def build_prompt(
-        self,
-        instruction: str,
-        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
-        output: Union[None, str] = None,
-    ) -> Generator[str, None, None]:
+    def _build_result(self, instruction, input_text, output):
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
-        if input:
+        if input_text:
             res = (
                 self.system_format.format(system=self.system_prompt)
                 if self.system_prompt
                 else ""
-            ) + self.turn_format.format(instruction=instruction, input=input)
+            ) + self.turn_format.format(instruction=instruction, input=input_text)
         else:
             res = (
                 self.system_format.format(system=self.system_no_input_prompt)
@@ -77,7 +74,21 @@ class AlpacaPrompter:
             ) + self.turn_no_input_format.format(instruction=instruction)
         if output:
             res = f"{res}{output}"
-        yield res
+
+        return res
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        yield self._build_result(instruction, input, output)
+
+    def __repr__(self) -> str:
+        return REPR_TEMPLATE.format(
+            full_prompt=self._build_result("{instruction}", "{input}", "{output}")
+        )
 
 
 class UnpromptedPrompter(AlpacaPrompter):
@@ -191,14 +202,14 @@ class ReflectAlpacaPrompter:
         )
         self.response_split = "ASSISTANT:"
 
-    def build_prompt(
+    def _build_result(
         self,
         instruction: str,
         input: Union[None, str] = None,  # pylint: disable=redefined-builtin
         output: Union[None, str] = None,
         reflection: Union[None, str] = None,
         corrected: Union[None, str] = None,
-    ) -> Generator[str, None, None]:
+    ):
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
         if input:
@@ -212,7 +223,30 @@ class ReflectAlpacaPrompter:
                 corrected=corrected,
             )
             res = f"{res}{label}"
-        yield res
+
+        return res
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
+        output: Union[None, str] = None,
+        reflection: Union[None, str] = None,
+        corrected: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        # pylint: disable=duplicate-code
+        yield self._build_result(
+            instruction,
+            input,
+            output,
+            reflection,
+            corrected,
+        )
+
+    def __repr__(self) -> str:
+        return REPR_TEMPLATE.format(
+            full_prompt=self._build_result("{instruction}", "{input}", "{output}")
+        )
 
 
 SHAREGPT_ASSERTION_FAILED_ROLE = (
@@ -247,7 +281,7 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
         if role_key_model:
             self.role_key_model = role_key_model
 
-    def build_prompt(self, source) -> Generator[str, None, None]:
+    def _build_result(self, source):
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
             # also happens on the data splitting leaving empty conversations
@@ -282,11 +316,20 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
                 LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
             conv.append_message(role, sentence["value"])
 
-        for part in conv.get_turns():
+        return conv.get_turns()
+
+    def build_prompt(self, source) -> Generator[str, None, None]:
+        turns = self._build_result(source)
+
+        for part in turns:
             if part[0] and not part[1]:
                 LOG.warning(f"role with empty message: {part[0]}")
             yield part
 
+    def __repr__(self) -> str:
+        turns = self._build_result([{"from": "{from}", "value": "{value}"}])
+        return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
+
 
 class ShareGPTPrompterV2(ShareGPTPrompter):
     """
@@ -304,3 +347,15 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
             role_key_human=role_key_human,
             role_key_model=role_key_model,
         )
+
+
+class UnsupportedPrompter:
+    """
+    A dummy class for custom prompters
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __repr__(self):
+        return "Pre-tokenized or custom dataset types are unsupported for logging"
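The new `__repr__` methods rely on a small trick: calling `_build_result` with the literal strings `"{instruction}"`, `"{input}"`, and `"{output}"` renders the template itself rather than a concrete example, so logging a prompter prints the exact prompt format. A self-contained sketch of the idea, using a simplified Alpaca-style turn format rather than the real prompter classes:

```python
# Sketch of the template-printing trick behind the new __repr__ methods.
# The turn format below is simplified for illustration; the real formats
# live in AlpacaPrompter and friends.
from colorama import Fore

REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"

turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"

# Passing the placeholder names through .format() reproduces the raw template.
full_prompt = (
    turn_format.format(instruction="{instruction}", input="{input}") + "{output}"
)
print(REPR_TEMPLATE.format(full_prompt=full_prompt))
```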
src/axolotl/utils/data.py CHANGED

@@ -3,7 +3,7 @@ import functools
 import hashlib
 import logging
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 from datasets import (
@@ -36,6 +36,7 @@ from axolotl.prompters import (
     MultipleChoiceExplainPrompter,
     ReflectAlpacaPrompter,
     SummarizeTLDRPrompter,
+    UnsupportedPrompter,
 )
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
@@ -55,9 +56,10 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
 
 
 def prepare_dataset(cfg, tokenizer):
+    prompters = []
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
-            train_dataset, eval_dataset = load_prepare_datasets(
+            train_dataset, eval_dataset, prompters = load_prepare_datasets(
                 tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
             )
     else:
@@ -70,7 +72,7 @@ def prepare_dataset(cfg, tokenizer):
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
-        return train_dataset, eval_dataset, cfg.max_steps
+        return train_dataset, eval_dataset, cfg.max_steps, prompters
 
     with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
@@ -83,7 +85,7 @@ def prepare_dataset(cfg, tokenizer):
         LOG.info(f"Maximum number of steps set at {total_num_steps}")
     else:
         total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
-    return train_dataset, eval_dataset, total_num_steps
+    return train_dataset, eval_dataset, total_num_steps, prompters
 
 
 def load_tokenized_prepared_datasets(
@@ -109,6 +111,7 @@ def load_tokenized_prepared_datasets(
         else Path(default_dataset_prepared_path) / ds_hash
     )
     dataset = None
+    prompters = []
     use_auth_token = cfg.hf_use_auth_token
     try:
         if cfg.push_dataset_to_hub:
@@ -147,13 +150,13 @@ def load_tokenized_prepared_datasets(
                 yield dataset
 
     # pylint: disable=invalid-name
-    for d in for_d_in_datasets(cfg.datasets):
+    for config_dataset in for_d_in_datasets(cfg.datasets):
        ds: Union[Dataset, DatasetDict] = None
        ds_from_hub = False
        try:
            load_dataset(
-                d.path,
-                name=d.name,
+                config_dataset.path,
+                name=config_dataset.name,
                 streaming=True,
                 token=use_auth_token,
             )
@@ -162,33 +165,33 @@ def load_tokenized_prepared_datasets(
             pass
 
         # prefer local dataset, even if hub exists
-        local_path = Path(d.path)
+        local_path = Path(config_dataset.path)
         if local_path.exists():
             if local_path.is_dir():
                 # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
                 ds = load_dataset(
-                    d.path,
-                    name=d.name,
-                    data_files=d.data_files,
+                    config_dataset.path,
+                    name=config_dataset.name,
+                    data_files=config_dataset.data_files,
                     streaming=False,
                     split=None,
                 )
             elif local_path.is_file():
                 ds_type = "json"
-                if d.ds_type:
-                    ds_type = d.ds_type
-                elif ".parquet" in d.path:
+                if config_dataset.ds_type:
+                    ds_type = config_dataset.ds_type
+                elif ".parquet" in config_dataset.path:
                     ds_type = "parquet"
-                elif ".arrow" in d.path:
+                elif ".arrow" in config_dataset.path:
                     ds_type = "arrow"
-                elif ".csv" in d.path:
+                elif ".csv" in config_dataset.path:
                     ds_type = "csv"
-                elif ".txt" in d.path:
+                elif ".txt" in config_dataset.path:
                     ds_type = "text"
                 ds = load_dataset(
                     ds_type,
-                    name=d.name,
-                    data_files=d.path,
+                    name=config_dataset.name,
+                    data_files=config_dataset.path,
                     streaming=False,
                     split=None,
                 )
@@ -198,25 +201,25 @@ def load_tokenized_prepared_datasets(
             )
         elif ds_from_hub:
             ds = load_dataset(
-                d.path,
-                name=d.name,
+                config_dataset.path,
+                name=config_dataset.name,
                 streaming=False,
-                data_files=d.data_files,
+                data_files=config_dataset.data_files,
                 token=use_auth_token,
             )
         else:
-            if isinstance(d.data_files, str):
+            if isinstance(config_dataset.data_files, str):
                 fp = hf_hub_download(
-                    repo_id=d.path,
+                    repo_id=config_dataset.path,
                     repo_type="dataset",
-                    filename=d.data_files,
+                    filename=config_dataset.data_files,
                 )
-            elif isinstance(d.data_files, list):
+            elif isinstance(config_dataset.data_files, list):
                 fp = []
-                for file in d.data_files:
+                for file in config_dataset.data_files:
                     fp.append(
                         hf_hub_download(
-                            repo_id=d.path,
+                            repo_id=config_dataset.path,
                             repo_type="dataset",
                             filename=file,
                         )
@@ -226,21 +229,27 @@ def load_tokenized_prepared_datasets(
                     "data_files must be either a string or list of strings"
                 )
             ds = load_dataset(
-                "json", data_files=fp, streaming=False, split=None
+                "json",
+                name=config_dataset.name,
+                data_files=fp,
+                streaming=False,
+                split=None,
             )
         if not ds:
             raise ValueError("unhandled dataset load")
         # support for using a subset of the data
-        if d.shards:
+        if config_dataset.shards:
             if "train" in ds:
                 ds = ds.shuffle(seed=seed)["train"].shard(
-                    num_shards=d.shards, index=0
+                    num_shards=config_dataset.shards, index=0
                 )
             else:
-                ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
+                ds = ds.shuffle(seed=seed).shard(
+                    num_shards=config_dataset.shards, index=0
+                )
 
         d_base_type = d_prompt_style = None
-        d_type = d.type
+        d_type = config_dataset.type
         if isinstance(d_type, str):
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
@@ -249,108 +258,26 @@ def load_tokenized_prepared_datasets(
             ds = ds["train"]
         elif (
             isinstance(ds, DatasetDict)
-            and d.train_on_split
-            and d.train_on_split in ds
+            and config_dataset.train_on_split
+            and config_dataset.train_on_split in ds
         ):
-            ds = ds[d.train_on_split]
+            ds = ds[config_dataset.train_on_split]
         elif isinstance(ds, DatasetDict):
             raise ValueError(
-                f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
-            )
-        if (
-            "input_ids" in ds.features
-            and "attention_mask" in ds.features
-            and "labels" in ds.features
-        ):
-            # dataset is already tokenized, just drop it straight in
-            datasets.append(ds)
-        elif isinstance(d.type, DictDefault):
-            ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif ds_strategy := load(d.type, tokenizer, cfg, d):
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "alpaca":
-            ds_strategy = AlpacaPromptTokenizingStrategy(
-                AlpacaPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "explainchoice":
-            ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                MultipleChoiceExplainPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "concisechoice":
-            ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                MultipleChoiceConcisePrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "summarizetldr":
-            ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
-                SummarizeTLDRPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "jeopardy":
-            ds_strategy = JeopardyPromptTokenizingStrategy(
-                JeopardyPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "oasst":
-            ds_strategy = OpenAssistantPromptTokenizingStrategy(
-                AlpacaPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "gpteacher":
-            ds_strategy = GPTeacherPromptTokenizingStrategy(
-                GPTeacherPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "reflection":
-            ds_strategy = AlpacaReflectionPTStrategy(
-                ReflectAlpacaPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        else:
-            suffix = ""
-            if ":load_" in d.type:
-                suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
-            LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
-            raise ValueError(
-                f"unhandled prompt tokenization strategy: {d.type} {suffix}"
+                f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
             )
+
+        dataset_wrapper, dataset_prompter = get_dataset_wrapper(
+            config_dataset=config_dataset,
+            dataset=ds,
+            tokenizer=tokenizer,
+            cfg=cfg,
+            d_base_type=d_base_type,
+            d_prompt_style=d_prompt_style,
+        )
+        datasets.append(dataset_wrapper)
+        prompters.append(dataset_prompter)
+
         LOG.info("merging datasets")
         dataset = concatenate_datasets(datasets)
 
@@ -368,14 +295,14 @@ def load_tokenized_prepared_datasets(
             f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
         )
 
-    return dataset
+    return dataset, prompters
 
 
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
-) -> Tuple[Dataset, Dataset]:
+) -> Tuple[Dataset, Dataset, List[Any]]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -384,6 +311,7 @@ def load_prepare_datasets(
     )  # make sure we don't accidentally set it larger than sequence_len
 
     tokenizer_name = tokenizer.__class__.__name__
+    prompters = []
    if cfg.max_packed_sequence_len is not None:
        # see if we can go ahead and load the stacked dataset
        seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -439,7 +367,7 @@ def load_prepare_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
             )
         else:
-            dataset = load_tokenized_prepared_datasets(
+            dataset, prompters = load_tokenized_prepared_datasets(
                 tokenizer, cfg, default_dataset_prepared_path
             )
 
@@ -481,7 +409,7 @@ def load_prepare_datasets(
                 private=True,
             )
         else:
-            dataset = load_tokenized_prepared_datasets(
+            dataset, prompters = load_tokenized_prepared_datasets(
                 tokenizer, cfg, default_dataset_prepared_path
             )
 
@@ -532,7 +460,124 @@ def load_prepare_datasets(
         train_dataset = dataset
         eval_dataset = None
 
-    return train_dataset, eval_dataset
+    return train_dataset, eval_dataset, prompters
+
+
+def get_dataset_wrapper(
+    config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
+):
+    dataset_wrapper = None
+    dataset_prompter = None
+
+    if (
+        "input_ids" in dataset.features
+        and "attention_mask" in dataset.features
+        and "labels" in dataset.features
+    ):
+        # dataset is already tokenized, just drop it straight in
+        dataset_prompter = UnsupportedPrompter()
+        dataset_wrapper = dataset
+    elif isinstance(config_dataset.type, DictDefault):
+        ds_strategy = load(
+            "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
+        )
+        dataset_prompter = UnsupportedPrompter()
+        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+    elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
+        dataset_prompter = UnsupportedPrompter()
+        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+    elif d_base_type == "alpaca":
+        dataset_prompter = AlpacaPrompter(d_prompt_style)
+        ds_strategy = AlpacaPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "explainchoice":
+        dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
+        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "concisechoice":
+        dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
+        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "summarizetldr":
+        dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
+        ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "jeopardy":
+        dataset_prompter = JeopardyPrompter(d_prompt_style)
+        ds_strategy = JeopardyPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "oasst":
+        dataset_prompter = AlpacaPrompter(d_prompt_style)
+        ds_strategy = OpenAssistantPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "gpteacher":
+        dataset_prompter = GPTeacherPrompter(d_prompt_style)
+        ds_strategy = GPTeacherPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "reflection":
+        dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
+        ds_strategy = AlpacaReflectionPTStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    else:
+        suffix = ""
+        if ":load_" in config_dataset.type:
+            suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
+        LOG.error(
+            f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
+        )
+        raise ValueError(
+            f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
        )
+
+    return dataset_wrapper, dataset_prompter
 
 
 def encode_pretraining(
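The extracted `get_dataset_wrapper` keeps the original if/elif chain. The same dispatch can be pictured as a lookup table from `d_base_type` to a (prompter class, strategy class) pair; the sketch below is a restructuring for clarity, not the commit's code, and it covers only the named base types while assuming the imports already present in this module:

```python
# Illustrative restructuring of get_dataset_wrapper's elif chain as a table.
# Class names match the imports in this file; the table itself is not part
# of the commit.
DATASET_STRATEGIES = {
    "alpaca": (AlpacaPrompter, AlpacaPromptTokenizingStrategy),
    "explainchoice": (MultipleChoiceExplainPrompter, AlpacaMultipleChoicePromptTokenizingStrategy),
    "concisechoice": (MultipleChoiceConcisePrompter, AlpacaMultipleChoicePromptTokenizingStrategy),
    "summarizetldr": (SummarizeTLDRPrompter, SummarizeTLDRPromptTokenizingStrategy),
    "jeopardy": (JeopardyPrompter, JeopardyPromptTokenizingStrategy),
    "oasst": (AlpacaPrompter, OpenAssistantPromptTokenizingStrategy),
    "gpteacher": (GPTeacherPrompter, GPTeacherPromptTokenizingStrategy),
    "reflection": (ReflectAlpacaPrompter, AlpacaReflectionPTStrategy),
}


def build_wrapper(d_base_type, d_prompt_style, dataset, tokenizer, cfg):
    # Look up the pair, build the prompter, wrap the tokenizing strategy,
    # and return both so callers can log the prompt template later.
    prompter_cls, strategy_cls = DATASET_STRATEGIES[d_base_type]
    prompter = prompter_cls(d_prompt_style)
    strategy = strategy_cls(prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len)
    return TokenizedPromptDataset(strategy, dataset), prompter
```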