support custom field for completion from yml (#580)

* support custom field for completion from yml
* remove legacy completion check and add doc
* update README docs
README.md
CHANGED
@@ -322,6 +322,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   - path: EleutherAI/pile
     name: enron_emails
     type: completion # format from earlier
+    field: text # Optional[str] default: text, field to use for completion data
 
 # huggingface repo with multiple named configurations/subsets
 datasets:
@@ -444,6 +445,9 @@ datasets:
       # 'no_input_format' cannot include {input}
       no_input_format: "{instruction} "
 
+    # for completions datsets, uses the provided field if not `text`
+    field:
+
 # axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
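The new `field` option points a completion dataset at a column other than the default `text`. A rough sketch of what that selection amounts to; the local file name `my_completions.jsonl` and the `content` column are assumptions for illustration, not part of this PR:

```python
# Sketch only: which column ends up being tokenized for a completion dataset.
from datasets import load_dataset

ds = load_dataset("json", data_files="my_completions.jsonl", split="train")

ds_cfg = {"field": "content"}        # mirrors `field: content` in the YAML entry
field = ds_cfg.get("field", "text")  # default stays "text" when `field` is omitted

for row in ds.select(range(3)):
    print(row[field][:80])           # the text the completion strategy will tokenize
```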
src/axolotl/prompt_strategies/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 """Module to load prompt strategies."""
 
 import importlib
+import inspect
 
 from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
 
@@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg):
         load_kwargs = {}
         if strategy == "user_defined":
             load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
+        else:
+            sig = inspect.signature(func)
+            if "ds_cfg" in sig.parameters:
+                load_kwargs["ds_cfg"] = ds_cfg
         return func(tokenizer, cfg, **load_kwargs)
     except Exception:  # pylint: disable=broad-exception-caught
         return None
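The loader change forwards the raw `ds_cfg` only to strategy modules whose `load()` declares that parameter, so older strategies keep working unchanged. A minimal, self-contained sketch of the same signature check (the two `load_*` functions here are stand-ins, not axolotl code):

```python
import inspect


def load_plain(tokenizer, cfg):
    # stand-in for a strategy that predates the ds_cfg parameter
    return "called without ds_cfg"


def load_with_ds_cfg(tokenizer, cfg, ds_cfg=None):
    # stand-in for a strategy, such as completion.load, that accepts ds_cfg
    return f"called with ds_cfg={ds_cfg}"


def dispatch(func, tokenizer, cfg, ds_cfg):
    load_kwargs = {}
    # only pass ds_cfg through when the target function can accept it
    if "ds_cfg" in inspect.signature(func).parameters:
        load_kwargs["ds_cfg"] = ds_cfg
    return func(tokenizer, cfg, **load_kwargs)


print(dispatch(load_plain, None, None, {"field": "content"}))
print(dispatch(load_with_ds_cfg, None, None, {"field": "content"}))
```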
src/axolotl/prompt_strategies/completion.py
ADDED
@@ -0,0 +1,20 @@
+"""
+Basic completion text
+"""
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
+from axolotl.prompters import CompletionPrompter
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    strat = CompletionPromptTokenizingStrategy(
+        CompletionPrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    if ds_cfg and "field" in ds_cfg:
+        strat.field = ds_cfg["field"]
+
+    return strat
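A hedged usage sketch of the new module: the tokenizer checkpoint, the `SimpleNamespace` stand-in for the axolotl config object, and the `content` column name are all assumptions for illustration.

```python
from types import SimpleNamespace

from transformers import AutoTokenizer

from axolotl.prompt_strategies.completion import load

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # assumed checkpoint
cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)   # stand-in config

strat = load(tokenizer, cfg, ds_cfg={"field": "content"})
print(strat.field)  # "content"; falls back to "text" when `field` is not set
```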
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -245,8 +245,31 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     Tokenizing strategy for Completion prompts.
     """
 
+    _field: str = "text"
+
+    @property
+    def field(self) -> str:
+        return self._field
+
+    @field.setter
+    def field(self, new_field: str):
+        self._field = new_field
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        return (
+            prompt[self.field],
+            "",
+            "",
+        )
+
     def tokenize_prompt(self, prompt):
-        full_prompt = self._build_full_prompt(prompt["text"], None, None)
+        (
+            instruction,
+            _,
+            _,
+        ) = self.parse_instruction_fields(prompt)
+
+        full_prompt = self._build_full_prompt(instruction, None, None)
         tokenized_full_prompt = self._tokenize(full_prompt)
 
         return tokenized_full_prompt
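The strategy now resolves the completion text through `prompt[self.field]` instead of a hard-coded `text` key. A small stand-alone sketch of that plumbing with the tokenizer parts stripped out; the class below mirrors the property and parsing logic but is not the axolotl class:

```python
class FieldOnlyStrategy:
    """Mirror of the field property + parse_instruction_fields logic only."""

    _field: str = "text"

    @property
    def field(self) -> str:
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt):
        # completion datasets carry no separate input/output, hence the empty strings
        return (prompt[self.field], "", "")


strategy = FieldOnlyStrategy()
print(strategy.parse_instruction_fields({"text": "default column"}))

strategy.field = "content"
print(strategy.parse_instruction_fields({"content": "custom column"}))
```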
src/axolotl/utils/data.py
CHANGED
@@ -22,7 +22,6 @@ from axolotl.prompt_tokenizers import (
     AlpacaMultipleChoicePromptTokenizingStrategy,
     AlpacaPromptTokenizingStrategy,
     AlpacaReflectionPTStrategy,
-    CompletionPromptTokenizingStrategy,
     GPTeacherPromptTokenizingStrategy,
     JeopardyPromptTokenizingStrategy,
     OpenAssistantPromptTokenizingStrategy,
@@ -31,7 +30,6 @@ from axolotl.prompt_tokenizers import (
 )
 from axolotl.prompters import (
     AlpacaPrompter,
-    CompletionPrompter,
     GPTeacherPrompter,
     JeopardyPrompter,
     MultipleChoiceConcisePrompter,
@@ -327,15 +325,6 @@ def load_tokenized_prepared_datasets(
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
-        elif d_base_type == "completion":
-            ds_strategy = CompletionPromptTokenizingStrategy(
-                CompletionPrompter(),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
         else:
             suffix = ""
             if ":load_" in d.type: