# Copyright (c) Meta Platforms, Inc. and affiliates.
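"""Pack variable-length BLT sequences into fixed-size training batches.

Two packing modes are supported: BYTES, which packs raw token/mask pairs, and
PATCHING, which additionally carries per-sequence patch lengths and truncates
or pads batches to a maximum length (see PackingArgs and truncate_batch).
"""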
from enum import Enum
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import Batch, BltSequence
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState


class PackingMode(str, Enum):
    BYTES = "bytes"
    PATCHING = "patching"


class PackingArgs(BaseModel):
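    """Configuration for batch packing; seq_len is measured in patches when
    packing_mode is PATCHING and in tokens (bytes) when it is BYTES."""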
    model_config = ConfigDict(extra="forbid")
    batch_size: int
    seq_len: int
    pad_id: int
    max_length: int | None
    pad_to_max_length: bool
    enable_byte_ngrams: bool
    packing_mode: PackingMode


class PackingIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    sequence_iterator_state: SamplingIteratorState
    packing_args: PackingArgs

    def build(self) -> "PackingIterator":
        return PackingIterator(
            sequence_iterator=self.sequence_iterator_state.build(),
            packing_args=self.packing_args,
        )


def _merge_patch_seq_masks(bs: int, slen: int, mask_seqs: list[list[bool]]):
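    """Merge per-sequence masks into a (bs, slen) boolean array of target positions.

    Worked example (hypothetical values): with bs=2, slen=3 and
    mask_seqs=[[True, True, True, True], [True, True, False, False]], row i
    receives mask_seqs[i][1:], giving [[True, True, True], [True, False, False]];
    unfilled tail positions stay False.
    """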
    assert len(mask_seqs) == bs
    lens = [len(m) for m in mask_seqs]
    if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
        return np.ones((bs, slen), dtype=bool)
    assert slen == max(lens) - 1, f"slen={slen} != max(lens)-1={max(lens) - 1}"
    mask = np.zeros((bs, slen), dtype=bool)
    for i, m in enumerate(mask_seqs):
        if m is None:
            raise NotImplementedError(
                "None masks are not supported here; callers must pass an "
                "explicit all-True mask instead."
            )
        mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:]
    return mask


def truncate_batch(
    batch: Batch,
    max_length: int,
    pad_id: int,
    pad_to_max_length: bool = False,
    *,
    enable_byte_ngrams: bool,
):
    """
    Truncate the x to a given size, making sure we remove the corresponding patch sizes in patch_lenghts
    and fixing the batch.mask.

    batch.patch_lengths has unchanged shape
    x,y, and mask may reduce in size
    """
    if batch.patch_lengths is None:
        return batch

    seq_lengths = batch.patch_lengths.sum(axis=1)
    max_length_adj = max_length + 1
    if np.any(seq_lengths > max_length_adj):
        for i in range(batch.x.shape[0]):
            if seq_lengths[i] > max_length_adj:
                # Find id of patch that tips over max_length + 1
                count, j = 0, 0
                while count + batch.patch_lengths[i, j] <= max_length_adj:
                    count += batch.patch_lengths[i, j]
                    j += 1
                # Edit the batch
                assert j < batch.patch_lengths.shape[1]
                batch.x[i, max_length:] = pad_id
                batch.y[i, max_length:] = pad_id
                if batch.mask is not None:
                    batch.mask[i, max_length:] = False
                batch.patch_lengths[i, j:] = 0
                batch.patch_lengths[i, j] = max_length_adj - count

        # Truncate if necessary.
        if max_length < batch.x.shape[1]:
            batch.x = batch.x[:, :max_length]
            batch.y = batch.y[:, :max_length]
            if batch.mask is not None:
                batch.mask = batch.mask[:, :max_length]

    # Right pad to max_length if necessary
    elif pad_to_max_length:
        if batch.x.shape[1] < max_length:
            # NOTE: the extra padding must be added to a real (non-zero) patch,
            # so find the last non-zero patch in each row.
            non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1
            non_zero_indices = np.maximum(0, non_zero_indices)
            batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += (
                max_length - batch.x.shape[1]
            )
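            # e.g. (hypothetical numbers) a row [1, 3, 3, 0] becomes [1, 3, 5, 0]
            # when two columns of padding are added to x.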
            # TODO: We could get rid of many of these complications by moving this function directly into the dataloader.
            x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype)
            x[:, : batch.x.shape[1]] = batch.x
            batch.x = x
        if batch.y.shape[1] < max_length:
            y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
            y[:, : batch.y.shape[1]] = batch.y
            batch.y = y
        if batch.mask is not None and batch.mask.shape[1] < max_length:
            mask = np.full(
                (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
            )
            mask[:, : batch.mask.shape[1]] = batch.mask
            batch.mask = mask

    assert batch.x.shape[1] <= max_length
    assert batch.y.shape[1] <= max_length
    assert batch.mask is None or batch.mask.shape[1] <= max_length
    assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0)
    if pad_to_max_length:
        assert batch.x.shape[1] == max_length
        assert batch.y.shape[1] == max_length
        assert batch.mask is None or batch.mask.shape[1] == max_length
    if enable_byte_ngrams:
        # Byte n-grams are not supported here yet. The intended result is an
        # ngram_ids array of shape (num_ngram, batch_size, seq_len) whose last
        # axis matches batch.x.
        raise NotImplementedError()
    else:
        ngram_ids = None
    batch.ngram_ids = ngram_ids


class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
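    """Pack BltSequence items from sequence_iterator into fixed-size Batch
    objects, either byte-wise or patch-wise depending on packing_args."""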
    def __init__(
        self,
        sequence_iterator: StatefulIterator[BltSequence, Any],
        *,
        packing_args: PackingArgs,
    ):
        self.sequence_iterator = sequence_iterator
        self.packing_args = packing_args

    def get_state(self):
        return PackingIteratorState(
            sequence_iterator_state=self.sequence_iterator.get_state(),
            packing_args=self.packing_args,
        )

    def create_iter(self):
        if self.packing_args.packing_mode == PackingMode.BYTES:
            return self._create_iter_from_bytes()
        elif self.packing_args.packing_mode == PackingMode.PATCHING:
            return self._create_iter_from_patch_lengths()
        else:
            raise ValueError(f"Invalid packing mode: {self.packing_args.packing_mode}")

    def _create_iter_from_bytes(self):
        sequence_iter = self.sequence_iterator.create_iter()
        batch_size = self.packing_args.batch_size
        pad_id = self.packing_args.pad_id
        seq_len = self.packing_args.seq_len
        while True:
            tokens: list[list[int]] = []
            masks: list[list[bool]] = []
            stop_iteration = False
            try:
                for _ in range(self.packing_args.batch_size):
                    sequence = next(sequence_iter)
                    _tokens = sequence.tokens
                    _mask = sequence.mask
                    assert (
                        sequence.patch_lengths is None
                    ), "patch_lengths should not be used in byte packing"
                    tokens.append(_tokens)
                    masks.append(_mask)
            except StopIteration:
                # No more sequences will arrive, so we stop after yielding the
                # data accumulated so far (at most one more batch). Either:
                # 1. The batch is complete and things proceed as normal, or
                # 2. The batch is incomplete; since we allocate a full-sized
                #    batch and then fill values in, this case is handled
                #    automatically.
                stop_iteration = True

            x = np.full((batch_size, seq_len), fill_value=pad_id)
            y = np.full((batch_size, seq_len), fill_value=pad_id)
            m = np.zeros((batch_size, seq_len), dtype=bool)
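            # Each row packs one sequence: x holds the tokens and y the targets
            # shifted left by one (e.g. tokens [5, 6, 7] fill x with [5, 6, 7]
            # and y with [6, 7], pad_id elsewhere); m marks real-token positions.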

            for i, tok_seq in enumerate(tokens):
                x[i, : len(tok_seq)] = tok_seq
                y[i, : len(tok_seq) - 1] = tok_seq[1:]
                m[i, : len(tok_seq)] = masks[i]
            batch = Batch(x=x, y=y, mask=m)
            assert (
                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
            yield batch
            if stop_iteration:
                break

    def _create_iter_from_patch_lengths(self):
        sequence_iter = self.sequence_iterator.create_iter()
        batch_size = self.packing_args.batch_size
        pad_id = self.packing_args.pad_id
        seq_len = self.packing_args.seq_len
        pad_to_max_length = self.packing_args.pad_to_max_length
        enable_byte_ngrams = self.packing_args.enable_byte_ngrams
        max_length = self.packing_args.max_length
        assert max_length is not None
        final_leftover_batch = False
        while True:
            tokens: list[list[int]] = []
            masks: list[list[bool]] = []
            patch_lengths: list[list[int]] = []
            stop_iteration = False
            try:
                for _ in range(self.packing_args.batch_size):
                    sequence = next(sequence_iter)
                    _tokens = sequence.tokens
                    _mask = sequence.mask
                    _patch_lengths = sequence.patch_lengths
                    assert (
                        _patch_lengths is not None
                    ), "patch lengths are required for packing based on patches."
                    # Reminder: seq_len is in terms of patches
                    assert len(sequence.patch_lengths) == self.packing_args.seq_len
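                    # Shift so the first patch holds exactly one token, e.g.
                    # (hypothetical values) patch lengths [3, 2, 2] over 7 tokens
                    # become [1, 2, 2] and the trailing 2 tokens are dropped,
                    # keeping the number of patches equal to seq_len.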
                    last_patch_length = 0
                    if _patch_lengths[0] > 1:
                        last_patch_length = _patch_lengths[-1]
                        _patch_lengths[0] -= 1
                        _patch_lengths = [1] + _patch_lengths[:-1]
                    tokens.append(_tokens[: len(_tokens) - last_patch_length])
                    masks.append(_mask[: len(_mask) - last_patch_length])
                    patch_lengths.append(_patch_lengths)
            except StopIteration:
                stop_iteration = True

            if len(tokens) == 0 and stop_iteration:
                break

            x_patch_lengths = np.array(patch_lengths)
            assert (
                x_patch_lengths.shape[1] == seq_len
            ), f"{x_patch_lengths.shape[1]} vs {seq_len}"
            # pad batch to same length
            tok_seq_len = max([len(toks) for toks in tokens]) - 1
            x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
            y = np.full((batch_size, tok_seq_len), fill_value=pad_id)

            for i, tok_seq in enumerate(tokens):
                x[i, : len(tok_seq) - 1] = tok_seq[:-1]
                y[i, : len(tok_seq) - 1] = tok_seq[1:]
                # Adjust patch lengths to match x
                x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)

            if x_patch_lengths.shape[0] < batch_size:
                if final_leftover_batch:
                    raise ValueError(
                        "There should only be one partial batch, but found multiple"
                    )
                final_leftover_batch = True
                assert len(masks) == len(x_patch_lengths)
                n_missing = batch_size - x_patch_lengths.shape[0]
                # Repeat the last row of patch lengths so the batch is padded to
                # full size, and append all-zero masks so the filler rows are ignored
                x_patch_lengths = np.vstack(
                    [
                        x_patch_lengths,
                        np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0),
                    ]
                )
                for _ in range(n_missing):
                    masks.append([0] * tok_seq_len)
                assert len(masks) == batch_size

            assert x_patch_lengths.shape == (
                batch_size,
                seq_len,
            ), f"{x_patch_lengths.shape} vs {(batch_size, seq_len)}"

            if enable_byte_ngrams:
                raise NotImplementedError()
            else:
                ngram_ids = None

            batch = Batch(
                x=x,
                y=y,
                patch_lengths=x_patch_lengths,
                ngram_ids=ngram_ids,
                mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
            )

            assert (
                x_patch_lengths.sum() == x.size + batch_size
            ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
            assert (
                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
            assert np.all(
                x_patch_lengths[:, 0] == 1
            ), f"first patch should always be 1, {x_patch_lengths[:, 0]}"
            # cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024)
            # cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024
            # print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}")
            truncate_batch(
                batch,
                max_length=max_length,
                pad_id=pad_id,
                pad_to_max_length=pad_to_max_length,
                enable_byte_ngrams=enable_byte_ngrams,
            )
            yield batch
            if stop_iteration:
                break
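

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the training
    # pipeline): exercises _merge_patch_seq_masks and truncate_batch on tiny
    # hand-built inputs, assuming Batch accepts the keyword arguments used above.
    merged = _merge_patch_seq_masks(
        bs=2,
        slen=3,
        mask_seqs=[[True, True, True, True], [True, True, False, False]],
    )
    print("merged mask:\n", merged)

    demo_batch = Batch(
        x=np.arange(16, dtype=np.int64).reshape(2, 8),
        y=np.arange(1, 17, dtype=np.int64).reshape(2, 8),
        mask=np.ones((2, 8), dtype=bool),
        patch_lengths=np.array([[1, 3, 3, 2], [1, 4, 2, 2]]),
    )
    truncate_batch(demo_batch, max_length=6, pad_id=0, enable_byte_ngrams=False)
    print("truncated x shape:", demo_batch.x.shape)
    print("adjusted patch lengths:\n", demo_batch.patch_lengths)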