# Copyright (c) Meta Platforms, Inc. and affiliates.
import math
import os
import time
from collections import defaultdict
from contextlib import nullcontext
from enum import Enum

import torch
from pydantic import BaseModel
from torch.nn import functional as F

from bytelatent.distributed import get_local_rank
from bytelatent.entropy_model import load_entropy_model

# from src.slurm import get_local_rank
from bytelatent.tokenizers.constants import BPE_ID, OFFSET


class PatchingModeEnum(str, Enum):
    entropy = "entropy"
    bpe = "bpe"
    bpe_patcher = "bpe_patcher"
    space = "space"
    static = "static"
    byte = "byte"


class PatcherArgs(BaseModel):
    patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
    patching_device: str = "cuda"
    entropy_model_checkpoint_dir: str | None = None
    realtime_patching: bool = False
    threshold: float = 1.335442066192627
    threshold_add: float | None = None
    max_patch_length: int | None = None
    patch_size: float = 4.5
    patching_batch_size: int = 1
    device: str = "cuda"
    monotonicity: bool = False
    log_time: bool = False

    def build(self) -> "Patcher":
        return Patcher(self)


def entropy(scores):
    """
    scores: [bs, seq_len, vocab]
    returns [bs, seq_len]
    Computes the entropy for each token in the batch.
    Note: uses natural log.
    """
    log_probs = F.log_softmax(scores, dim=-1)
    probs = torch.exp(log_probs)
    p_log_p = log_probs * probs
    entropy = -p_log_p.sum(dim=-1)
    return entropy
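

# Illustrative usage sketch (not part of the original module): uniform logits over a
# vocabulary of size V give the maximum entropy ln(V).
#   scores = torch.zeros(2, 5, 4)   # [bs=2, seq_len=5, vocab=4]
#   entropy(scores)                 # every entry ~= math.log(4) ~= 1.3863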


def calculate_entropies(
    tokens: torch.Tensor,
    entropy_model,
    patching_batch_size,
    device: str | None = None,
    enable_grad: bool = False,
):
    """
    tokens: 2D tensor of shape [batch_size, seq_len]
    Return 2D tensor of shape [batch_size, seq_len] with entropies for each token.

    Splits the tokens into chunks of up to max_length * patching_batch_size elements
    and calculates entropies for each chunk.
    The entropy model can be executed on cpu or gpu; specify either 'cuda' or 'cpu'
    in the device argument.
    """
    grad_context = nullcontext() if enable_grad else torch.no_grad()
    with grad_context:
        entropies = []
        preds = []
        max_length = getattr(entropy_model, "max_length", 8192)
        batch_numel = max_length * patching_batch_size
        splits = torch.split(tokens.flatten(), batch_numel)
        for split in splits:
            pad_size = (max_length - (split.numel() % max_length)) % max_length
            pad = torch.zeros(
                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
            )
            split = torch.cat((split, pad), dim=0)
            split = split.reshape(-1, max_length)
            if device is not None:
                split = split.to(device)
            # assert torch.all(split >= 0) and torch.all(split < 260)
            pred = entropy_model(split)
            pred = pred.reshape(-1, pred.shape[-1])[
                : split.numel() - pad_size, :
            ]  # [batch_size * seq_len, vocab]
            preds.append(pred)
            pred_entropies = entropy(pred)
            entropies.append(pred_entropies)
        concat_entropies = torch.cat(entropies, dim=0)
        concat_entropies = concat_entropies.reshape(tokens.shape)
        concat_preds = torch.cat(preds, dim=0)
        concat_preds = concat_preds.reshape(tokens.shape[0], -1)
    return concat_entropies, concat_preds
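

# Illustrative usage sketch (not part of the original module; assumes `entropy_model`
# is a module mapping [bs, seq_len] token ids to [bs, seq_len, vocab] logits and
# optionally exposing a `max_length` attribute):
#   tokens = torch.randint(0, 260, (2, 100))
#   entropies, _ = calculate_entropies(tokens, entropy_model, patching_batch_size=1, device="cpu")
#   entropies.shape   # torch.Size([2, 100])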


def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
    """
    entropies: [bs, seq_len] torch tensor of entropies
    t: threshold
    returns [bs, seq_len] mask where True indicates the start of a patch
    """
    bs, seq_len = entropies.shape
    if seq_len == 0:
        return entropies > t
    mask = torch.zeros_like(entropies, dtype=torch.bool)
    mask[:, 0] = True
    # Calculate differences between consecutive elements along the sequence length
    differences = entropies[:, 1:] - entropies[:, :-1]
    # Calculate conditions for all elements except the first one in each sequence
    condition = differences > t
    # Update the mask based on the condition
    mask[:, 1:] = condition
    return mask
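

# Illustrative usage sketch (not part of the original module): position 0 always starts
# a patch; a later position starts one when its entropy exceeds the previous token's by
# more than t.
#   e = torch.tensor([[0.1, 1.0, 1.2, 0.2]])
#   patch_start_mask_from_entropy_with_monotonicity(e, t=0.5)
#   # tensor([[True, True, False, False]])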


def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
    """
    entropies: [bs, seq_len] torch tensor of entropies
    t: global entropy threshold
    t_add: minimum increase over the previous token's entropy
    returns [bs, seq_len] mask where True indicates the start of a patch
    """
    bs, seq_len = entropies.shape
    if seq_len == 0:
        return entropies > t
    mask = torch.zeros_like(entropies, dtype=torch.bool)
    mask[:, 0] = True
    # Calculate differences between consecutive elements along the sequence length
    differences = entropies[:, 1:] - entropies[:, :-1]
    # Calculate conditions for all elements except the first one in each sequence
    condition = (differences > t_add) & (entropies[:, 1:] > t) & (~mask[:, :-1])
    # Update the mask based on the condition
    mask[:, 1:] = condition
    return mask
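

# Illustrative usage sketch (not part of the original module): position 0 always starts
# a patch; a later position starts one when its entropy exceeds t and rises by more than
# t_add over the previous token (position 1 never qualifies because position 0 already
# starts a patch).
#   e = torch.tensor([[0.1, 1.0, 1.2, 0.2]])
#   patch_start_mask_global_and_monotonicity(e, t=0.5, t_add=0.0)
#   # tensor([[True, False, True, False]])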


def patch_start_ids_from_patch_start_mask(patch_start_mask):
    bs, trunc_seq_len = patch_start_mask.shape
    max_patches = patch_start_mask.sum(dim=1).max()
    if max_patches == 0:
        patch_start_ids = torch.full(
            (bs, trunc_seq_len),
            trunc_seq_len,
            dtype=torch.long,
            device=patch_start_mask.device,
        )
    else:
        patch_ids = (
            torch.arange(trunc_seq_len, device=patch_start_mask.device)
            .unsqueeze(0)
            .repeat(bs, 1)
        )
        extra_patch_ids = torch.full(
            (bs, trunc_seq_len),
            trunc_seq_len,
            dtype=torch.long,
            device=patch_start_mask.device,
        )
        all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
        patch_start_mask_padded = torch.cat(
            (patch_start_mask, ~patch_start_mask), dim=1
        )
        patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(
            bs, trunc_seq_len
        )[:, :max_patches]
    return patch_start_ids
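

# Illustrative usage sketch (not part of the original module): True positions become
# start ids; rows with fewer starts than the batch maximum are right-padded with the
# truncated sequence length.
#   m = torch.tensor([[True, False, True, False]])
#   patch_start_ids_from_patch_start_mask(m)   # tensor([[0, 2]])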


def check_non_zero_after_zero(tensor):
    zero_mask = tensor == 0
    shifted_mask = torch.cat(
        [
            torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
            zero_mask[:, :-1],
        ],
        dim=1,
    )
    non_zero_after_zero = (tensor != 0) & shifted_mask
    return non_zero_after_zero.any()
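

# Illustrative usage sketch (not part of the original module): flags a non-zero entry
# appearing after a zero, i.e. invalid right padding.
#   check_non_zero_after_zero(torch.tensor([[3, 0, 2]]))   # tensor(True)
#   check_non_zero_after_zero(torch.tensor([[3, 2, 0]]))   # tensor(False)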


def patch_lengths_from_start_ids(patch_start_ids, seq_len):
    """
    Calculate patch lengths from start ids.
    start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches
    (here 0, 1), and then the rest are filled to the seq len.
    seq_len: ex: 7 length of the sequence
    returns the patch lengths:
    [1, 6] for the above example.
    """
    last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
    patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
    patch_lengths = patch_end_ids - patch_start_ids + 1
    assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
    assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
    return patch_lengths
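

# Illustrative usage sketch (not part of the original module), matching the docstring
# example; trailing zeros are right padding:
#   ids = torch.tensor([[0, 1, 7, 7, 7, 7, 7]])
#   patch_lengths_from_start_ids(ids, seq_len=7)   # tensor([[1, 6, 0, 0, 0, 0, 0]])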


def find_space_patch_start_ids(tokens):
    bs, seq_len = tokens.shape
    tokens_no_offset = tokens - OFFSET
    patch_end_mask = (
        (tokens_no_offset < ord("0"))
        | ((ord("9") < tokens_no_offset) & (tokens_no_offset < ord("A")))
        | ((ord("Z") < tokens_no_offset) & (tokens_no_offset < ord("a")))
        | ((ord("z") < tokens_no_offset) & (tokens_no_offset < 0b1000_0000))
        | (0b1100_0000 <= tokens_no_offset)
    )
    patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not()
    patch_end_mask |= tokens < OFFSET
    patch_start_mask = torch.cat(
        [
            torch.tensor([1, 1], device=tokens.device, dtype=torch.bool)
            .unsqueeze(0)
            .repeat(bs, 1),
            patch_end_mask[:, 1:],
        ],
        dim=1,
    )
    max_patches = patch_start_mask.sum(dim=1).max()
    patch_ids = (
        torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0).repeat(bs, 1)
    )
    extra_patch_ids = torch.full(
        (bs, seq_len + 1), seq_len + 1, dtype=torch.long, device=tokens.device
    )
    all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
    patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
    patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, -1)[
        :, :max_patches
    ]
    return patch_start_ids
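

# Illustrative usage sketch (not part of the original module): the byte after a space
# (or other non-alphanumeric ASCII byte) starts a new patch; positions 0 and 1 always do.
#   tokens = torch.tensor([[ord(c) + OFFSET for c in "hi there"]])
#   find_space_patch_start_ids(tokens)   # tensor([[0, 1, 3]])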


def to_device(entropy_model, device=None):
    if device == "cuda":
        rank = get_local_rank()
        device = f"cuda:{rank}"
    entropy_model = entropy_model.to(device)
    return entropy_model, device


def model_pred_to_bpe_patching_pred(pred):
    _, indices = torch.max(pred, dim=1)
    return indices == BPE_ID


def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None):
    assert tokens.device == torch.device(
        "cpu"
    ), f"{tokens.device} != cpu expects tokens to be on cpu"
    with torch.no_grad():
        bpe_patcher_device, device = to_device(
            bpe_patcher, device
        )  # Move the BPE patcher model to the correct rank device.
        bpe_patching_mask = []
        max_length = getattr(bpe_patcher, "max_length", 8192)
        batch_numel = max_length * patching_batch_size
        splits = torch.split(tokens.flatten(), batch_numel)
        for split in splits:
            pad_size = (max_length - (split.numel() % max_length)) % max_length
            pad = torch.zeros(
                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
            )
            split = torch.cat((split, pad), dim=0)
            split = split.reshape(-1, max_length).to(device)
            assert torch.all(split >= 0) and torch.all(split < 260)
            pred = bpe_patcher_device(split)
            pred_cpu = pred[0].cpu()
            pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[
                : split.numel() - pad_size, :
            ]  # [batch_size * seq_len, vocab]
            bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu)
            bpe_patching_mask.append(bpe_patching_pred)
        bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0)
        bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape)
    return bpe_patching_mask


def find_bpe_patcher_patch_start_ids(
    tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True
):
    bs, seq_len = tokens.shape
    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    preds_truncation_len = first_ids.shape[1]
    token_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1]
    if token_input.shape[1] >= 1:
        patch_start_mask = apply_bpe_patcher(
            token_input, bpe_patcher, patching_batch_size, device
        )
        assert (
            patch_start_mask.shape[1]
            == tokens.shape[1] + include_next_token - preds_truncation_len
        ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
        patch_start_ids = torch.cat(
            (first_ids, patch_start_ids + preds_truncation_len), dim=1
        )
    else:
        patch_start_ids = first_ids
    return patch_start_ids


def find_entropy_patch_start_ids(
    entropies,
    patch_size=None,
    threshold=None,
    threshold_add=None,
    monotonicity=False,
    include_next_token=True,
):
    """
    Use entropies to find the start ids of each patch.
    Use patch_size or threshold to figure out the total number of patches to allocate.

    When threshold is not None, the number of patches is not constant between
    different sequences, but patches can be identified incrementally rather than
    decided globally using the entire sequence.
    """
    bs, seq_len = entropies.shape[:2]
    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    # Remove the first preds because they will be the start of patches.
    preds_truncation_len = first_ids.shape[1]
    entropies = entropies[:, 1:]
    if threshold is None:
        num_patches = seq_len // patch_size
        patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
        patch_start_ids = patch_start_ids.sort(dim=1).values
    else:
        # Assumes that there is at least one token going over the threshold
        if monotonicity:
            patch_start_mask = patch_start_mask_from_entropy_with_monotonicity(
                entropies, threshold
            )
        elif threshold_add is not None and threshold is not None:
            patch_start_mask = patch_start_mask_global_and_monotonicity(
                entropies, threshold, threshold_add
            )
        else:
            patch_start_mask = entropies > threshold
        if not include_next_token:
            patch_start_mask = patch_start_mask[:, :-1]
        # patch_start_mask[1:] |= tokens[:-1] < OFFSET
        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
    patch_start_ids = torch.cat(
        (first_ids, patch_start_ids + preds_truncation_len), dim=1
    )
    return patch_start_ids
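

# Illustrative usage sketch (not part of the original module), threshold mode: positions
# 0 and 1 always start patches; afterwards, the entropy at position i decides whether
# position i + 1 starts a new patch.
#   e = torch.tensor([[0.1, 1.0, 1.2, 0.2, 0.9]])
#   find_entropy_patch_start_ids(e, threshold=0.8)   # tensor([[0, 1, 2, 3, 5]])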


def rightpad(seq, pad_id, max_len):
    return seq + [pad_id] * (max_len - len(seq))
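

# Illustrative usage sketch (not part of the original module):
#   rightpad([3, 2], pad_id=0, max_len=5)   # [3, 2, 0, 0, 0]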


def find_bpe_delim_patch_start_ids(tokens, delim):
    ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False)
    out = [[0, 1] for _ in range(tokens.shape[0])]
    for x, y in ids:
        # start is at delim + 1, delim should be the last element in the patch.
        out[x.item()].append(y.item() + 1)
    max_len = max([len(elt) for elt in out])
    out = [rightpad(elt, tokens.shape[1], max_len) for elt in out]
    patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device)
    return patch_start_ids
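

# Illustrative usage sketch (not part of the original module), with an arbitrary
# delimiter id of 5 standing in for BPE_ID:
#   tokens = torch.tensor([[7, 5, 8, 9, 5, 3]])
#   find_bpe_delim_patch_start_ids(tokens, delim=5)   # tensor([[0, 1, 2, 5]])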


def find_lookup_table_start_mask(
    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
    window_size = lookup_table.ndim
    # Unfold the tensor to get sliding windows
    unfolded = tokens.unfold(1, window_size, 1)
    # Gather indices for each dimension
    indices = [unfolded[..., i] for i in range(window_size)]
    # Access the lookup table using the gathered indices
    result = lookup_table[indices]
    return result


def find_lookup_table_patch_start_ids(
    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
    bs, seq_len = tokens.shape
    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    preds_truncation_len = first_ids.shape[1]
    window_size = lookup_table.ndim
    assert window_size == 2, f"{window_size} != 2"
    # The mask has token_input.shape[1] - window_size + 1 positions; together with
    # first_ids this should cover tokens.shape[1] (+ 1 if include_next_token).
    token_input = (
        tokens if include_next_token else tokens[:, : -preds_truncation_len + 1]
    )
    if token_input.shape[1] >= window_size:
        patch_start_mask = find_lookup_table_start_mask(
            token_input, lookup_table, include_next_token
        )
        assert (
            patch_start_mask.shape[1]
            == tokens.shape[1] + include_next_token - preds_truncation_len
        ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
        patch_start_ids = torch.cat(
            (first_ids, patch_start_ids + preds_truncation_len), dim=1
        )
    else:
        patch_start_ids = first_ids
    return patch_start_ids


def split_large_numbers(lst, m):
    new_lst = []
    for i in lst:
        if i > m:
            while i > m:
                new_lst.append(m)
                i -= m
            new_lst.append(i)
        else:
            new_lst.append(i)
    assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
    return new_lst
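

# Illustrative usage sketch (not part of the original module): lengths above the cap m
# are split into chunks of at most m while preserving the total.
#   split_large_numbers([3, 9, 2], m=4)   # [3, 4, 4, 1, 2]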


class Patcher:
    def __init__(self, patcher_args: PatcherArgs):
        self.patcher_args = patcher_args
        self.patching_mode = patcher_args.patching_mode
        self.realtime_patching = patcher_args.realtime_patching
        if self.realtime_patching:
            assert (
                patcher_args.entropy_model_checkpoint_dir is not None
            ), "Cannot require realtime patching without an entropy model checkpoint"
            maybe_consolidated = os.path.join(
                patcher_args.entropy_model_checkpoint_dir,
                "consolidated/consolidated.pth",
            )
            if os.path.exists(maybe_consolidated):
                state_path = maybe_consolidated
            else:
                state_path = os.path.join(
                    patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
                )
            entropy_model = load_entropy_model(
                patcher_args.entropy_model_checkpoint_dir,
                state_path,
            )
            entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
            self.entropy_model = entropy_model
        else:
            self.entropy_model = None
        self.threshold = patcher_args.threshold
        self.threshold_add = patcher_args.threshold_add
        self.max_patch_length = patcher_args.max_patch_length
        self.patch_size = patcher_args.patch_size
        self.patching_batch_size = patcher_args.patching_batch_size
        self.device = patcher_args.device
        self.monotonicity = patcher_args.monotonicity
        self.log_time = patcher_args.log_time
        if self.log_time:
            self.log = defaultdict(float)

    def patch(
        self,
        tokens: torch.Tensor,
        include_next_token: bool = False,
        preds: torch.Tensor | None = None,
        entropies: torch.Tensor | None = None,
        threshold: float | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
        Returns patch lengths and optionally scores associated with the tokens
        (i.e. entropies, logprobs etc.)
        -> output tensor: [batch_size, max_num_patches]
        Each tensor is processed independently and gets right padded with zeros.

        Patching with the following modes:
        1. patching_mode = "static": fixed patch size
           ("byte" is the special case of one-byte patches)
        2. patching_mode = "entropy":
           calculate the entropy of each token, allocate patches so that the total
           number of patches is the same as static patching but choose to begin
           patches on tokens where the model is most uncertain (highest entropy).
           When threshold is provided, it uses the threshold to decide when to
           start a new patch.
        3. patching_mode = "space":
           use space-like tokens to define the patches.
        4. patching_mode = "bpe":
           use bpe delim tokens to define the patches.

        To correctly patch the last token, it may be necessary to include the next
        token in the patch lengths calculations. This is controlled by the
        include_next_token argument.
        """
        bs, seq_len = tokens.shape
        seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
        scores = None
        # STATIC
        if self.patching_mode == PatchingModeEnum.static:
            patch_lengths = torch.zeros(
                (bs, math.ceil(seq_len_next_tok / self.patch_size)),
                dtype=tokens.dtype,
                device=tokens.device,
            ).fill_(self.patch_size)
            if seq_len_next_tok % self.patch_size != 0:
                patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
        elif self.patching_mode == PatchingModeEnum.byte:
            patch_lengths = torch.ones(
                (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
            )
        # ENTROPY
        elif self.patching_mode == PatchingModeEnum.entropy:
            if self.log_time:
                s = time.time()
            if entropies is not None:
                scores = entropies.to(dtype=torch.float32)
            elif preds is not None:
                scores = entropy(preds)
            else:
                start_entropies = time.time()
                scores, _ = calculate_entropies(
                    tokens,
                    self.entropy_model,
                    self.patching_batch_size,
                    self.device,
                )
            if self.log_time:
                self.log["calculate_entropies"] += time.time() - s
                s = time.time()
            patch_start_ids = find_entropy_patch_start_ids(
                scores,
                self.patch_size,
                include_next_token=include_next_token,
                threshold=threshold if threshold is not None else self.threshold,
                threshold_add=self.threshold_add,
                monotonicity=self.monotonicity,
            )
            if self.log_time:
                self.log["find_entropy_patch_start_ids"] += time.time() - s
                s = time.time()
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
            if self.log_time:
                self.log["patch_lengths_from_start_ids"] += time.time() - s
                s = time.time()
        # BPE
        elif self.patching_mode == PatchingModeEnum.bpe:
            patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID)
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
        elif self.patching_mode == PatchingModeEnum.bpe_patcher:
            patch_start_ids = find_bpe_patcher_patch_start_ids(
                tokens,
                self.entropy_model,
                self.patching_batch_size,
                self.device,
                include_next_token,
            )
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
        # SPACE
        elif self.patching_mode == PatchingModeEnum.space:
            patch_start_ids = find_space_patch_start_ids(tokens)
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
        else:
            raise NotImplementedError(f"self.patching_mode {self.patching_mode}")
        # Apply any processing to patch lengths
        if self.max_patch_length is not None:
            # TODO: avoid going back to a list here.
            patch_lengths = [
                split_large_numbers(pl, self.max_patch_length)
                for pl in patch_lengths.tolist()
            ]
            max_len = max([len(pl) for pl in patch_lengths])
            patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
            patch_lengths = torch.tensor(
                patch_lengths, dtype=tokens.dtype, device=tokens.device
            )
        assert not check_non_zero_after_zero(patch_lengths)
        # Find the last non-zero column index using argmax on a reversed version of the tensor
        last_non_zero_col_reversed = (
            (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
        )
        # Slice the tensor up to the last non-zero column
        patch_lengths = patch_lengths[
            :, : patch_lengths.shape[1] - last_non_zero_col_reversed
        ]
        assert (
            torch.sum(patch_lengths)
            == tokens.numel() + include_next_token * tokens.shape[0]
        ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
        if self.log_time:
            self.log["postprocessing_patch_lengths"] += time.time() - s
            self.log["tokens"] += patch_lengths.sum().item()
        return patch_lengths, scores
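

# Illustrative usage sketch (not part of the original module): static patching needs no
# entropy model checkpoint, and the device fields are unused in this mode, so it also
# runs on CPU.
#   patcher = PatcherArgs(patching_mode=PatchingModeEnum.static, patch_size=4).build()
#   tokens = torch.arange(10).unsqueeze(0)         # [1, 10]
#   patch_lengths, scores = patcher.patch(tokens)  # tensor([[4, 4, 2]]), None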