Template-free prompt construction with the input_output format
Background
Masking Inputs
One of the most popular features of axolotl is setting the following configuration value:
train_on_inputs: false
If you declare a dataset format such as alpaca or chatml, axolotl knows what is an input (i.e. human) vs. an output (i.e. the assistant) and masks the input labels so that your model can focus on predicting the outputs only.
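To make the masking concrete, here is a minimal sketch of the idea. It is not axolotl's actual implementation: the function name mask_inputs and the token ids are made up for illustration. Labels start as a copy of the token ids, and every position that belongs to an input is replaced with -100, the value that PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def mask_inputs(input_ids, is_input):
    """Return labels where every input position is replaced by IGNORE_INDEX.

    input_ids: list of token ids for the full prompt
    is_input:  list of booleans, True where the token belongs to an input
    """
    if len(input_ids) != len(is_input):
        raise ValueError("input_ids and is_input must have the same length")
    return [IGNORE_INDEX if masked else tok for tok, masked in zip(input_ids, is_input)]


# Hypothetical ids: the first three tokens are the human input, the rest the reply.
input_ids = [101, 102, 103, 201, 202]
is_input = [True, True, True, False, False]
labels = mask_inputs(input_ids, is_input)
print(labels)  # [-100, -100, -100, 201, 202]
```

The model still sees the input tokens as context; they just contribute nothing to the loss.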
You may not want prompt templates
However, there are many situations where you don't want to use one of these formats or templates (I usually don't!). This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters <|im_start|> that can quickly become footguns if you don't include them correctly at inference time.
- Enforce a chat interface when you do not want one. Sometimes you just want to fine-tune a model to a very specific task and do NOT want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
The input_output format
You can construct your prompts without a template by using the input_output format. Set type: input_output in your configuration file like this:
config.yml
train_on_inputs: false # Mask segments of your data
datasets:
  - path: output.jsonl
    type: input_output # use template free prompt construction
Unlike type: completion, which is also template-free, type: input_output allows you to mask segments of your text. More details on how this works are described below.
Usage
This is how you can use the input_output format:
1. Prepare Data
To use the input_output format, collect your data into a jsonl file in the following format (below is the first row from the file output.jsonl, pretty printed):
$ head -n1 output.jsonl | python -m json.tool
{
    "segments": [
        {
            "label": true,
            "text": "<s>Hello\n"
        },
        {
            "label": true,
            "text": "hi there!. "
        },
        {
            "label": false,
            "text": "goodbye "
        },
        {
            "label": true,
            "text": "farewell</s>"
        }
    ]
}
Set "label": false when you want to mask a segment of text so that the model isn't trained on it. Some things to keep in mind:
- EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl concatenates all the segments as-is. The tokenizer doesn't add anything additional. Notice how I added spaces, newlines, <s> (BOS), and </s> (EOS) myself.
- Make sure you check the materialized output to validate that the prompt is getting assembled how you like.
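A minimal sketch of producing such a file from Python, using only the standard library. The segment texts are the ones from the example row above; the helper name make_row is illustrative and not part of axolotl. Joining the segment texts shows the exact string that will be tokenized, since the segments are concatenated as-is.

```python
import json


def make_row(segments):
    """Build one input_output row from (text, label) pairs."""
    return {"segments": [{"label": label, "text": text} for text, label in segments]}


row = make_row([
    ("<s>Hello\n", True),
    ("hi there!. ", True),
    ("goodbye ", False),   # masked: the model is not trained on this segment
    ("farewell</s>", True),
])

# One JSON object per line is what makes the file JSONL.
with open("output.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(row) + "\n")

# The prompt is just the segments joined with nothing in between.
full_text = "".join(seg["text"] for seg in row["segments"])
print(repr(full_text))  # '<s>Hello\nhi there!. goodbye farewell</s>'
```

Printing the repr makes stray or missing spaces and newlines visible before you spend time on preprocessing.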
2. Use type: input_output
Let's materialize data with our output.jsonl file by setting type: input_output in our axolotl config:
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
  - path: output.jsonl
    type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
You can use the following command to materialize your data. The --debug flag will print the tokens, along with the labels, so you can verify that the correct items are being ignored:
$ python -m axolotl.cli.preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
The format is decoded_token(label, token_id). For example, <s>(1, 1) means that the token is <s>, the label is 1 and the token_id is 1. When the label is -100, that token is ignored for training.
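If the debug line is long, a small script can pull out the masked tokens for you. The sketch below is a convenience heuristic, not an axolotl feature: it assumes each entry looks exactly like decoded_token(label, token_id) and that decoded tokens do not themselves contain parentheses. The string is the debug output shown above.

```python
import re

# One entry per token: decoded_token(label, token_id). The decoded token is
# captured as empty when it is pure whitespace (a space or a newline).
TOKEN_RE = re.compile(r"(\S*?)\((-?\d+), (\d+)\)")

debug_line = (
    "<s>(1, 1) Hello(22557, 22557)\n"
    "(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) "
    ".(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) "
    "(-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)"
)

entries = [(tok, int(label), int(tok_id)) for tok, label, tok_id in TOKEN_RE.findall(debug_line)]
masked = [(tok, tok_id) for tok, label, tok_id in entries if label == -100]

print(len(entries))  # 14
print(masked)        # [('good', 1179), ('bye', 17664), ('', 28705)]
```

The three masked entries are exactly the tokens of the "goodbye " segment, including its trailing space.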
3. Check the prompts
Here is another way to check the materialized output:
import os

import yaml
from datasets import load_from_disk
from transformers import AutoTokenizer

# List the prepared dataset directories (the original notebook used `!ls last_run_prepared/`)
directory = sorted(os.listdir('last_run_prepared'))
with open('training_config.yaml', 'r') as f:
    cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
We can check that the right tokens are ignored by comparing the labels to each token:
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
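| | token | label | id |
|---|---|---|---|
| 0 | <s> | 1 | 1 |
| 1 | Hello | 22557 | 22557 |
| 2 | \n | 13 | 13 |
| 3 | hi | 12014 | 12014 |
| 4 | there | 736 | 736 |
| 5 | ! | 28808 | 28808 |
| 6 | . | 28723 | 28723 |
| 7 | | 28705 | 28705 |
| 8 | good | -100 | 1179 |
| 9 | bye | -100 | 17664 |
| 10 | | -100 | 28705 |
| 11 | fare | 19111 | 19111 |
| 12 | well | 5458 | 5458 |
| 13 | </s> | 2 | 2 |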
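For longer rows, scanning a table token by token gets tedious. As a rough sketch, you can group consecutive tokens into trained and masked runs instead. The helper name split_by_mask is illustrative, and the ids and labels are copied from the debug output for the example row so that the snippet runs without a tokenizer.

```python
from itertools import groupby

IGNORE_INDEX = -100

# Values copied from the debug output for the example row.
input_ids = [1, 22557, 13, 12014, 736, 28808, 28723, 28705, 1179, 17664, 28705, 19111, 5458, 2]
labels = [1, 22557, 13, 12014, 736, 28808, 28723, 28705, -100, -100, -100, 19111, 5458, 2]


def split_by_mask(input_ids, labels):
    """Group consecutive tokens into (is_masked, token_ids) runs."""
    pairs = zip(input_ids, labels)
    return [
        (is_masked, [tok_id for tok_id, _ in run])
        for is_masked, run in groupby(pairs, key=lambda p: p[1] == IGNORE_INDEX)
    ]


runs = split_by_mask(input_ids, labels)
for is_masked, ids in runs:
    print("masked " if is_masked else "trained", ids)
```

In the notebook you could pass each run's ids to tok.decode to read the text of every segment; for this row there is a single masked run, which corresponds to the "goodbye " segment.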
If we look at the input data, the above table seems correct! (The jsonl version is repeated below for reference):
$ head -n1 output.jsonl | python -m json.tool
{
    "segments": [
        {
            "label": true,
            "text": "<s>Hello\n"
        },
        {
            "label": true,
            "text": "hi there!. "
        },
        {
            "label": false,
            "text": "goodbye "
        },
        {
            "label": true,
            "text": "farewell</s>"
        }
    ]
}