# Copyright (c) Meta Platforms, Inc. and affiliates.
import pickle
from pathlib import Path

import numpy as np

from bytelatent import ByteLatentError

# Ids below LOOKUP_OFFSET are reserved: 0 marks an n-gram that is missing from
# the lookup table, so every found n-gram id is shifted up by this amount.
LOOKUP_OFFSET = 4


def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1):
    """
    Build a function that maps each n-gram to its id in the lookup table.

    :param ngram_to_idx: Dictionary where keys are tuples (n-grams) and values are their indices.
    :param lookup_offset: Offset to add to the lookup result.
    :return: A function mapping an n-gram array to its offset id, or 0 if not found.
    """

    def apply_lookup_table(ngram):
        """
        Function to apply to each n-gram: converts it to a tuple and looks it up in the dictionary.

        :param ngram: Array of numbers representing an n-gram.
        :return: The index of the n-gram tuple plus lookup_offset, or 0 if not found.
        """
        # Convert the n-gram to a tuple so it can be used as a dictionary key
        ngram_tuple = tuple(ngram)
        if ngram_tuple not in ngram_to_idx:
            return 0
        else:
            return ngram_to_idx[ngram_tuple] + lookup_offset

    return apply_lookup_table
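
# A minimal sketch of the returned function's behavior (the table below is
# hypothetical, not a real n-gram table):
#
#     lookup = apply_lookup_table_wrapper({(1, 2): 0, (3, 4): 1}, lookup_offset=4)
#     lookup(np.array([1, 2]))  # -> 4 (index 0 + offset 4)
#     lookup(np.array([9, 9]))  # -> 0 (unknown n-gram)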


def get_byte_ngrams_ids(
    byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0
):
    """
    Generate n-gram ids from a 2D numpy array of bytes.

    :param byte_array: 2D numpy array of byte values, shape [batch, seq_len].
    :param n: The length of each n-gram.
    :param ngram_to_idx: Dictionary mapping n-gram tuples to their indices.
    :param pad_value: The value used for left-padding the byte values so the output keeps the input's dimensions.
    :return: A 2D numpy array, same shape as byte_array, where each element is the id of an n-gram offset by LOOKUP_OFFSET.
    """
    num_rows, num_cols = byte_array.shape
    # Left-pad each row with n - 1 pad values so a full n-gram ends at every position
    padded_array = np.pad(
        byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value
    )
    # Use stride tricks to build a [num_rows, num_cols, n] view of all n-grams
    # without explicit looping
    strided = np.lib.stride_tricks.as_strided
    shape = (num_rows, num_cols, n)
    strides = padded_array.strides[:2] + (padded_array.strides[1],)
    ngrams = strided(padded_array, shape=shape, strides=strides)
    ngram_ids = np.apply_along_axis(
        apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams
    )
    assert ngram_ids.shape == byte_array.shape
    return ngram_ids
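
# A quick sketch of the sliding-window lookup above, using a hypothetical
# 2-gram table:
#
#     table = {(0, 5): 0, (5, 6): 1}
#     get_byte_ngrams_ids(np.array([[5, 6, 7]]), 2, table)
#     # position 0 sees the left-padded 2-gram (0, 5) -> 0 + LOOKUP_OFFSET = 4
#     # position 1 sees (5, 6)                        -> 1 + LOOKUP_OFFSET = 5
#     # position 2 sees (6, 7), not in the table      -> 0
#     # result: array([[4, 5, 0]])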


def reload_tables(
    ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET
) -> tuple[dict[int, dict[tuple, int]], dict[int, list], dict[int, int]]:
    """
    Reload lookup tables from a directory. Reload only the n-gram sizes in the
    dictionary and, per n-gram size, only load up to the specified number of
    n-grams. Return the n-gram-to-index tables, the index-to-n-gram tables, and
    the vocabulary size (including the offset) per n-gram size.
    """
    idx_to_ngram_tables = {}
    ngram_to_idx_tables = {}
    vocab_sizes = {}
    for ngram, size in ngram_to_size.items():
        with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f:
            # These are already sorted by count
            # Value: tuple of: count, ngram, dataset
            ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[
                "counts"
            ]
        table = [ngram for ngram, _ in ngram_data][:size]
        if len(table) != size:
            raise ValueError(
                f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}"
            )
        ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)}
        actual_size = len(table)
        idx_to_ngram_tables[ngram] = table
        ngram_to_idx_tables[ngram] = ngram_to_idx
        vocab_sizes[ngram] = actual_size + offset
    return ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes
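
# A minimal sketch of the on-disk layout reload_tables() expects (the file name
# pattern and "counts" key follow the code above; the values are hypothetical):
#
#     data = {"counts": [((5, 6), (120, 2, "wiki")), ((6, 7), (80, 2, "web"))]}
#     with open("tables/ngram-2.pickle", "wb") as f:
#         pickle.dump(data, f)
#     ngram_to_idx, idx_to_ngram, vocab_sizes = reload_tables("tables", {2: 2})
#     # ngram_to_idx[2] == {(5, 6): 0, (6, 7): 1}
#     # vocab_sizes[2] == 2 + LOOKUP_OFFSET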


def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
    if ngram_to_size_str is None:
        return None
    ngram_to_size = {}
    for entry in ngram_to_size_str.split(","):
        ngram, size = entry.split(":")
        ngram_to_size[int(ngram)] = int(size)
    return ngram_to_size
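
# For example, "2:1000,3:500" parses to {2: 1000, 3: 500}: keep the top 1000
# 2-grams and the top 500 3-grams.
#
#     assert parse_ngram_to_size("2:1000,3:500") == {2: 1000, 3: 500}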


class NgramProcessor:
    def __init__(
        self,
        ngram_table_dir: str | None = None,
        ngram_to_size: dict[int, int] | None = None,
    ):
        if ngram_table_dir is None or ngram_to_size is None:
            raise ByteLatentError(
                "ngram_table_dir and ngram_to_size cannot be None if enable_byte_ngrams is True"
            )
        (
            self.ngram_to_idx_tables,
            self.idx_to_ngram_tables,
            self.ngram_vocab_sizes,
        ) = reload_tables(ngram_table_dir, ngram_to_size)
        # Lowest to highest ngram
        self.ngram_sizes = sorted(self.ngram_to_idx_tables.keys())
        # Although the model might not use all the ngrams, we need the tokenizer
        # to produce ngram_ids such that index zero is the 2-gram, later on in
        # src.model.megabyte.Megabyte.forward
        assert self.ngram_sizes[0] == 2

    def encode_single_ngram_table(self, data: np.ndarray, n: int):
        """
        Return the n-gram ids of the input data for a given n,
        as a numpy array of ids with shape data.shape.
        """
        return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0)

    def encode_token_ngrams(self, data: np.ndarray):
        """
        Return the n-gram ids of the input data for every configured n.
        Output shape: [ids with data.shape for n in self.ngram_sizes]
        """
        return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes]
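
# End-to-end usage, as a minimal sketch (the table directory and sizes are
# hypothetical; a 2-gram table must exist because __init__ asserts on it):
#
#     processor = NgramProcessor(
#         ngram_table_dir="tables", ngram_to_size={2: 1000, 3: 500}
#     )
#     byte_ids = np.array([[72, 101, 108, 108, 111]])  # "Hello" as bytes
#     ngram_ids = processor.encode_token_ngrams(byte_ids)
#     # ngram_ids is a list with one array of shape (1, 5) per n in
#     # processor.ngram_sizes, here n = 2 and n = 3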