|
"""Module for testing streaming dataset sequence packing""" |
|
import pytest |
|
from datasets import concatenate_datasets, load_dataset |
|
from torch.utils.data import DataLoader, RandomSampler |
|
from transformers import AutoTokenizer |
|
|
|
from axolotl.datasets import TokenizedPromptDataset |
|
from axolotl.prompt_strategies.completion import load |
|
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq |
|
from axolotl.utils.dict import DictDefault |
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths |
|
|
|
|
|
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """Provide a llama tokenizer whose pad token is set to "</s>"."""
    llama_tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    # llama has no pad token by default; reuse the EOS string so padding works.
    llama_tokenizer.pad_token = "</s>"
    return llama_tokenizer
|
|
|
|
|
@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
    """Maximum packed sequence length shared by the packing tests."""
    return 4 * 1024  # 4096 tokens
|
|
|
|
|
class TestBatchedSamplerPacking:
    """
    Test class for packing streaming dataset sequences
    """

    @pytest.mark.parametrize(
        "batch_size, num_workers",
        [
            (1, 0),
            (2, 0),
            (1, 2),
            (2, 2),
        ],
    )
    def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
        """Pack a tiny completion dataset and check the collated batch shapes
        and per-row padding values for various batch sizes / worker counts."""
        # Imported for its side effect: patches the dataloader's batch fetcher.
        import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import

        raw_dataset = load_dataset(
            "Trelis/tiny-shakespeare",
            split="train",
        )

        cfg = DictDefault(
            {
                "train_on_inputs": True,
                "sequence_len": max_seq_length,
            }
        )
        ds_cfg = DictDefault(
            {
                "field": "Text",
            }
        )

        # Tokenize the raw text via the completion prompt strategy.
        strategy = load(tokenizer, cfg, ds_cfg)
        wrapped_dataset = TokenizedPromptDataset(strategy, raw_dataset)
        train_dataset = concatenate_datasets([wrapped_dataset])

        sampler = MultipackBatchSampler(
            sampler=RandomSampler(train_dataset),
            batch_size=batch_size,
            drop_last=True,
            batch_max_len=max_seq_length,
            lengths=get_dataset_lengths(train_dataset),
        )

        collator = V2BatchSamplerDataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding=True,
            pad_to_multiple_of=max_seq_length,
            return_tensors="pt",
        )
        loader = DataLoader(
            train_dataset,
            batch_sampler=sampler,
            collate_fn=collator,
            num_workers=num_workers,
        )
        batch = next(iter(loader))

        # Every collated tensor is padded out to the full sequence length.
        for key in ("input_ids", "labels", "attention_mask"):
            assert batch[key].shape == (batch_size, max_seq_length)

        # Each packed row begins with padding: the pad token id (2, since the
        # fixture sets pad_token to "</s>") for inputs, -100 (the ignore index)
        # for labels, and 0 in the attention mask. The final attention-mask
        # value exceeds 1 because the V2 collator numbers each packed
        # sub-sequence rather than using a binary mask.
        for row in range(batch_size):
            assert batch["input_ids"].tolist()[row][0] == 2
            assert batch["labels"].tolist()[row][0] == -100
            assert batch["attention_mask"].tolist()[row][0] == 0
            assert batch["attention_mask"].tolist()[row][-1] > 1
|
|