# Copyright (c) Meta Platforms, Inc. and affiliates.

from bytelatent.constants import BLT_DATA
from bytelatent.data.iterators.dev_iterators import (
    BltTestIterator,
    BltTestWithEntropiesIterator,
)
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


def test_preprocess_iter():
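    """Tokenizing test data should yield integer byte tokens with a matching
    mask, and BOS/EOS should add two tokens beyond the text length. Exercised
    for both the bpe and space patching modes.
    """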
    total = 3
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    for mode in [
        PatchingModeEnum.bpe,
        PatchingModeEnum.space,
    ]:
        data_it = BltTestIterator(total)
        patcher_args = PatcherArgs(patching_mode=mode)
        example_it = PreprocessIterator(
            data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
        )
        count = 0
        for example in example_it.create_iter():
            assert isinstance(example.tokens, list)
            assert isinstance(example.tokens[0], int)
            # BOS and EOS account for the two extra tokens beyond the text length
            assert len(example.tokens) == len(example.text) + 2
            assert example.mask is not None
            assert len(example.tokens) == len(example.mask)
            count += 1

        assert count == total


def test_non_entropy_patch_iter():
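    """For the bpe and space patching modes, patch lengths should be integers
    that sum exactly to the number of tokens in each example.
    """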
    total = 3
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    for mode in [
        PatchingModeEnum.bpe,
        PatchingModeEnum.space,
    ]:
        patcher_args = PatcherArgs(patching_mode=mode)
        data_it = BltTestIterator(total)
        example_it = PreprocessIterator(
            data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
        )

        count = 0
        for example in example_it.create_iter():
            assert isinstance(example.patch_lengths, list)
            assert isinstance(example.patch_lengths[0], int)
            assert len(example.tokens) == sum(example.patch_lengths)
            count += 1

        assert count == total


def test_entropy_patch_iter():
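    """Entropy-based patching over examples with precomputed entropies and a
    fixed threshold should also produce integer patch lengths that sum to the
    token count.
    """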
    total = 2
    patcher_args = PatcherArgs(
        patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627
    )
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    data_it = BltTestWithEntropiesIterator(total)
    example_it = PreprocessIterator(
        data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
    )

    count = 0
    for example in example_it.create_iter():
        assert isinstance(example.patch_lengths, list)
        assert isinstance(example.patch_lengths[0], int)
        assert len(example.tokens) == sum(example.patch_lengths)
        count += 1

    assert count == total