File size: 5,587 Bytes
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright (c) Meta Platforms, Inc. and affiliates.
import pickle
from pathlib import Path

import numpy as np

from bytelatent import ByteLatentError

# N-gram IDs produced by the lookups in this module are `table index + LOOKUP_OFFSET`,
# and 0 is returned for n-grams missing from a table, so IDs 1..LOOKUP_OFFSET - 1 are
# never emitted here (presumably reserved for special tokens — confirm with the model).
LOOKUP_OFFSET: int = 4


def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1):
    """
    Build a function that maps a single n-gram to its offset table ID.

    :param ngram_to_idx: Dictionary whose keys are n-gram tuples and whose values are
        the n-gram's index in the table.
    :param lookup_offset: Offset added to every index found in ``ngram_to_idx``.
    :return: A function ``apply_lookup_table(ngram) -> int`` that returns
        ``ngram_to_idx[tuple(ngram)] + lookup_offset`` for known n-grams and 0 otherwise.
    """

    def apply_lookup_table(ngram):
        """
        Convert ``ngram`` to a tuple and look it up in ``ngram_to_idx``.

        :param ngram: Iterable of numbers representing an n-gram (e.g. a 1D numpy array).
        :return: The n-gram's table index plus ``lookup_offset``, or 0 if the n-gram
            is not in the table.
        """
        # Tuples are hashable, so the n-gram can be used directly as a dict key.
        idx = ngram_to_idx.get(tuple(ngram))
        if idx is None:
            # 0 marks an unknown n-gram; known ones are shifted by lookup_offset.
            return 0
        return idx + lookup_offset

    return apply_lookup_table


def get_byte_ngrams_ids(
    byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0
):
    """
    Map every position of a 2D array to the ID of the n-gram ending at that position.

    Each row is left-padded with ``n - 1`` copies of ``pad_value`` so position ``j``
    sees the n-gram ``padded_row[j : j + n]`` (i.e. the ``n`` values ending at the
    original ``row[j]``), and the output has the same shape as the input.

    :param byte_array: 2D array of values (one sequence per row).
    :param n: The length of each n-gram.
    :param ngram_to_idx: Mapping from n-gram tuples to their index in the n-gram table.
    :param pad_value: The value used for padding of the byte values to maintain the same dimensions for the n-grams.
    :return: A 2D integer numpy array with ``byte_array.shape`` where each element is the
        ID of an n-gram offset by LOOKUP_OFFSET, or 0 if the n-gram is not in the table.
    """
    num_rows, num_cols = byte_array.shape
    if num_rows == 0 or num_cols == 0:
        # Nothing to look up. (np.apply_along_axis raises on empty iteration
        # dimensions, and a zero-width row is narrower than the n-gram window.)
        return np.zeros((num_rows, num_cols), dtype=np.int_)

    # Left-pad so that the first n - 1 positions still form full-length n-grams.
    padded_array = np.pad(
        byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value
    )

    # Bounds-checked, read-only view of shape (num_rows, num_cols, n) with
    # ngrams[i, j] == padded_array[i, j : j + n]; replaces raw as_strided.
    ngrams = np.lib.stride_tricks.sliding_window_view(padded_array, n, axis=1)

    lookup = apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET)
    # .tolist() hands the lookup plain Python scalars; iterating them is much
    # cheaper than np.apply_along_axis calling back once per window.
    ids = [lookup(ngram) for row in ngrams.tolist() for ngram in row]
    ngram_ids = np.array(ids, dtype=np.int_).reshape(num_rows, num_cols)
    assert ngram_ids.shape == byte_array.shape
    return ngram_ids


def reload_tables(
    ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET
) -> tuple[dict[int, dict[tuple, int]], dict[int, list[tuple]], dict[int, int]]:
    """
    Reload lookup tables from a directory. Reload only the ngrams in the dictionary and per ngram,
    only load up to the max specified size. Return the actual number of ngrams taken per ngram size.

    :param ngram_table_dir: Directory containing one ``ngram-{n}.pickle`` file per n-gram size.
    :param ngram_to_size: Mapping from n-gram size ``n`` to the number of n-grams to keep.
    :param offset: Added to each table's size to obtain its vocabulary size.
    :return: ``(ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes)``, each keyed by ``n``:
        ``ngram_to_idx_tables[n]`` maps an n-gram tuple to its index,
        ``idx_to_ngram_tables[n]`` is the list of kept n-grams (index -> n-gram),
        ``vocab_sizes[n]`` is the number of kept n-grams plus ``offset``.
    :raises ValueError: If a table file holds fewer n-grams than requested.
    """
    ngram_to_idx_tables = {}
    idx_to_ngram_tables = {}
    vocab_sizes = {}
    for n, size in ngram_to_size.items():
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # ngram_table_dir at trusted files.
        with open(Path(ngram_table_dir) / f"ngram-{n}.pickle", "rb") as f:
            ngram_data = pickle.load(f)["counts"]
        # These are already sorted by count, so the first `size` entries are the
        # ones kept. Each entry is unpacked as (ngram, metadata); the metadata is
        # presumably (count, ..., dataset) — confirm against the table writer.
        table = [ngram for ngram, _ in ngram_data[:size]]
        if len(table) != size:
            raise ValueError(
                f"Ngram table for {n}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}"
            )
        ngram_to_idx_tables[n] = {ngram: idx for idx, ngram in enumerate(table)}
        idx_to_ngram_tables[n] = table
        vocab_sizes[n] = len(table) + offset
    # Order matches the caller's unpacking: (ngram_to_idx, idx_to_ngram, vocab_sizes).
    # Previously ngram_to_idx_tables was returned twice and idx_to_ngram_tables was lost.
    return ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes


def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]:
    if ngram_to_size_str is None:
        return None
    ngram_to_size = {}
    for entry in ngram_to_size_str.split(","):
        ngram, size = entry.split(":")
        ngram = int(ngram)
        size = int(size)
        ngram_to_size[ngram] = size
    return ngram_to_size


class NgramProcessor:
    """
    Encode arrays of byte values into per-position n-gram IDs, one output array per
    configured n-gram size, using tables loaded by ``reload_tables``.
    """

    def __init__(
        self,
        ngram_table_dir: str | None = None,
        ngram_to_size: dict[int, int] | None = None,
    ):
        """
        :param ngram_table_dir: Directory holding the ``ngram-{n}.pickle`` tables.
        :param ngram_to_size: Mapping from n-gram size to the number of n-grams to load.
        :raises ByteLatentError: If either argument is None, or the loaded sizes do not
            start at 2.
        """
        if ngram_table_dir is None or ngram_to_size is None:
            raise ByteLatentError(
                "ngram_table_dir and ngram_to_size cannot be none if enable_byte_ngrams is True"
            )
        (
            self.ngram_to_idx_tables,
            self.idx_to_ngram_tables,
            self.ngram_vocab_sizes,
        ) = reload_tables(ngram_table_dir, ngram_to_size)
        # Lowest to highest ngram
        self.ngram_sizes = sorted(self.ngram_to_idx_tables)
        # Although the model might not use all the ngrams, we need the tokenizer
        # to produce ngram_ids such that index zero is the 2-gram, later on in
        # src.model.megabyte.Megabyte.forward
        # An explicit check (not `assert`) so it survives `python -O`, and an empty
        # configuration reports a clear error instead of an IndexError.
        if not self.ngram_sizes or self.ngram_sizes[0] != 2:
            raise ByteLatentError(
                f"ngram_to_size must contain 2 as its smallest ngram size, got {self.ngram_sizes}"
            )

    def encode_single_ngram_table(self, data: np.ndarray, n: int):
        """
        Return the n-gram IDs of the input data for a given n.

        :param data: 2D array of byte values.
        :param n: N-gram size; must be one of ``self.ngram_sizes``.
        :return: Numpy array of IDs with shape ``data.shape``.
        """
        return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0)

    def encode_token_ngrams(self, data: np.ndarray):
        """
        Return the n-gram IDs of the input data for every configured n-gram size.

        :param data: 2D array of byte values.
        :return: List with one ID array of shape ``data.shape`` per n in
            ``self.ngram_sizes`` (lowest n first).
        """
        return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes]