Byte Latent Transformer (BLT)
Overview
The BLT model was proposed in Byte Latent Transformer: Patches Scale Better Than Tokens by Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman, Srinivasan Iyer. BLT is a byte-level LLM that matches tokenization-based performance through entropy-based dynamic patching.
The abstract from the paper is the following:
We introduce the Byte Latent Transformer (BLT), a new byte-level LLM architecture that, for the first time, matches tokenization-based LLM performance at scale with significant improvements in inference efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve as the primary units of computation. Patches are segmented based on the entropy of the next byte, allocating more compute and model capacity where increased data complexity demands it. We present the first flop controlled scaling study of byte-level models up to 8B parameters and 4T training bytes. Our results demonstrate the feasibility of scaling models trained on raw bytes without a fixed vocabulary. Both training and inference efficiency improve due to dynamically selecting long patches when data is predictable, along with qualitative improvements on reasoning and long tail generalization. Overall, for fixed inference costs, BLT shows significantly better scaling than tokenization-based models, by simultaneously growing both patch and model size.
Usage Tips:

- Dual Model Architecture: BLT consists of two separately trained models:
  - Patcher (Entropy Model): A smaller transformer model that predicts byte-level entropy to determine patch boundaries and segment the input.
  - Main Transformer Model: The primary model that processes the patches through a Local Encoder, Global Transformer, and Local Decoder.
- Dynamic Patching: The model uses entropy-based dynamic patching (see the sketch after these tips), where:
  - High-entropy regions (complex data) get shorter patches with more computational attention.
  - Low-entropy regions (predictable data) get longer patches for efficiency.
  - This allows the model to allocate compute resources where they are most needed.
- Local Encoder: Processes byte sequences with cross-attention to patch embeddings.
- Global Transformer: Processes patch-level representations with full attention across patches.
- Local Decoder: Generates output with cross-attention back to the original byte sequence.
- Byte-Level Tokenizer: Unlike traditional tokenizers that use learned vocabularies, BLT's tokenizer simply converts text to UTF-8 bytes and maps each byte to a token ID. There is no need for a vocabulary.
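The patch segmentation is simple to sketch. Below is a minimal, illustrative version of entropy-based patching, not the library's exact internals: the patcher model scores each position with the entropy of its next-byte distribution, and a new patch starts wherever entropy crosses the configured patching_threshold (about 1.34 by default). The entropy_patch_boundaries helper and the random logits standing in for real patcher outputs are assumptions for illustration.

import torch

def entropy_patch_boundaries(logits, threshold=1.34):
    # logits: (seq_len, 256) next-byte logits from the patcher model
    probs = torch.softmax(logits, dim=-1)
    # Shannon entropy of the next-byte distribution at each position
    entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)
    boundaries = [0]
    for i, h in enumerate(entropy.tolist()):
        # start a new patch where the next byte is "surprising"
        if h > threshold and i > boundaries[-1]:
            boundaries.append(i)
    return boundaries

logits = torch.randn(32, 256)  # stand-in for real patcher outputs
print(entropy_patch_boundaries(logits))

Predictable stretches (low entropy) produce few boundaries and hence long patches; complex stretches produce many short ones.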
The model can be loaded via:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "itazap/blt-1b-hf",
    device_map="auto",
)

NUM_TOKENS_TO_GENERATE = 64  # number of new tokens (bytes) to generate

prompt = "my name is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False
)
print(tokenizer.decode(generated_ids[0]))
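Because the tokenizer is byte-level, token count tracks UTF-8 byte count rather than word count, and multi-byte characters simply become several tokens. A quick check (whether special tokens are added, and the exact ID mapping, may vary by checkpoint):

text = "héllo"  # 'é' is two bytes in UTF-8
print(len(text.encode("utf-8")))          # 6 bytes
print(len(tokenizer(text)["input_ids"]))  # roughly 6, plus any special tokens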
This model was contributed by itazap. The original code can be found here.
BltConfig
class transformers.BltConfig
< source >( vocab_size = 260 max_position_embeddings = 4096 patch_in_forward = True patch_size = 4 patching_mode = 'entropy' patching_threshold = 1.335442066192627 patching_batch_size = 1 max_patch_length = None cross_attn_k = 2 encoder_hash_byte_group_size = None encoder_hash_byte_group_vocab = 500002 encoder_hash_byte_group_nb_functions = 1 patcher_config = None encoder_config = None decoder_config = None global_config = None tie_word_embeddings = False initializer_range = 0.02 rope_theta = 500000.0 rope_scaling = None **kwargs )
Parameters
- vocab_size (`int`, optional, defaults to 260) — Vocabulary size of the Blt model. Defines the number of different tokens that can be represented by the `input_ids` passed when calling BltModel.
- max_position_embeddings (`int`, optional, defaults to 4096) — The maximum sequence length that this model might ever be used with.
- patch_in_forward (`bool`, optional, defaults to `True`) — Whether to perform patching during the forward pass.
- patch_size (`int`, optional, defaults to 4) — Size of the patches used in the patching mechanism.
- patching_mode (`str`, optional, defaults to `"entropy"`) — The mode used for patching, such as entropy-based patching.
- patching_threshold (`float`, optional, defaults to 1.34) — Threshold value used for determining when to apply patches.
- patching_batch_size (`int`, optional, defaults to 1) — Batch size used during the patching process.
- max_patch_length (`int`, optional) — Maximum length of patches that can be generated.
- cross_attn_k (`int`, optional, defaults to 2) — Number of cross-attention heads used in the model.
- encoder_hash_byte_group_size (`list`, optional) — List of byte group sizes used in the encoder hash function.
- encoder_hash_byte_group_vocab (`int`, optional, defaults to 500002) — Vocabulary size for the encoder hash byte groups.
- encoder_hash_byte_group_nb_functions (`int`, optional, defaults to 1) — Number of hash functions used in the encoder byte grouping.
- patcher_config (`BltPatcherConfig`, optional) — Configuration for the patcher component of the model.
- encoder_config (`BltLocalEncoderConfig`, optional) — Configuration for the local encoder component of the model.
- decoder_config (`BltLocalDecoderConfig`, optional) — Configuration for the local decoder component of the model.
- global_config (`BltGlobalTransformerConfig`, optional) — Configuration for the global transformer component of the model.
- tie_word_embeddings (`bool`, optional, defaults to `False`) — Whether to tie weight embeddings.
- initializer_range (`float`, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rope_theta (`float`, optional, defaults to 500000.0) — The base period of the RoPE embeddings.
- rope_scaling (`dict`, optional) — Dictionary containing the RoPE scaling configuration.
This is the configuration class to store the configuration of a BltModel. It is used to instantiate a Blt model according to the specified arguments, defining the model architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
>>> from transformers import BltModel, BltConfig
>>> # Initializing a Blt configuration
>>> configuration = BltConfig()
>>> # Initializing a model from the configuration
>>> model = BltModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
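The patching behavior can be tuned through the same configuration object. An illustrative (untested) variation using the parameters documented above:

>>> # Customizing entropy patching (illustrative values)
>>> configuration = BltConfig(patching_threshold=1.5, max_patch_length=16)
>>> model = BltModel(configuration)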
Checkpoint: facebook/blt
BltModel
forward
< source >( input_ids: typing.Optional[torch.LongTensor] = None patch_lengths: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None use_cache: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] )
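A minimal sketch of running the base model to obtain hidden states (the checkpoint repeats the one from the loading example above; the `last_hidden_state` attribute assumes the standard base-model output convention):

import torch
from transformers import AutoTokenizer, BltModel

tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
model = BltModel.from_pretrained("itazap/blt-1b-hf")

inputs = tokenizer("my name is", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)

When patch_in_forward=True (the default), patch_lengths can be left unset and the patcher computes them on the fly.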
BltForCausalLM
class transformers.BltForCausalLM
< source >( config: BltConfig )
Parameters
- config (BltConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The Blt Text Model with a language modeling head on top.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
< source >( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None cross_attention_states: typing.Optional[torch.LongTensor] = None cross_attention_mask: typing.Optional[torch.LongTensor] = None full_text_row_masked_out_mask: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None past_key_values: typing.Union[transformers.cache_utils.Cache, list[torch.FloatTensor], NoneType] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None logits_to_keep: typing.Union[int, torch.Tensor] = 0 **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) → transformers.modeling_outputs.CausalLMOutputWithPast or tuple(torch.FloatTensor)
Parameters
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, optional) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, optional) — Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  - 1 for tokens that are not masked,
  - 0 for tokens that are masked.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.
- cross_attention_states (`torch.FloatTensor`, optional) — Output of the vision model, used for cross-attention. This tensor contains the processed image features that the language model will attend to.
- cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, optional) — Cross-attention mask to control the interaction between text tokens and image tiles. This 4D tensor defines which image tiles each text token should attend to. For each text token (in seq_length):
  - 1 indicates the token should attend to the corresponding image tile,
  - 0 indicates the token should not attend to the corresponding image tile.
- full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, optional) — A tuple containing two tensors that mask out rows in the cross-attention mechanism:
  - The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1. A value of 0 indicates that the corresponding text token's entire row in the cross-attention matrix should be masked out (all image tokens ignored).
  - The second tensor has the same shape and is used internally to apply the masking during the forward pass of cross-attention layers. This mask is derived from the cross_attention_mask and is used to handle cases where a text token should not attend to any image token.
- past_key_values (`Union[~cache_utils.Cache, list[torch.FloatTensor], NoneType]`) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Only a Cache instance is allowed as input; see our kv cache guide. If no `past_key_values` are passed, a DynamicCache will be initialized by default. The model will output the same cache format that is fed as input. If `past_key_values` are used, the user is expected to input only unprocessed `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, unprocessed_length)` instead of all `input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, optional) — Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, optional) — Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- use_cache (`bool`, optional) — If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`).
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, optional) — Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.
- logits_to_keep (`Union[int, torch.Tensor]`, defaults to `0`) — If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only the last token's logits are needed for generation, and computing them only for that token can save significant memory for long sequences or a large vocabulary size. If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length).
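For example, restricting the logits to the final position is how generation loops typically save memory. A hedged sketch, reusing the `model` and `inputs` from the loading example earlier in this page:

outputs = model(**inputs, logits_to_keep=1)
print(outputs.logits.shape)  # (batch_size, 1, vocab_size) instead of (batch_size, sequence_length, vocab_size)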
Returns

transformers.modeling_outputs.CausalLMOutputWithPast or tuple(torch.FloatTensor)

A transformers.modeling_outputs.CausalLMOutputWithPast or a tuple of `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the configuration (BltConfig) and inputs.

- loss (`torch.FloatTensor` of shape `(1,)`, optional, returned when `labels` is provided) — Language modeling loss (for next-token prediction).
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`Cache`, optional, returned when `use_cache=True` is passed or when `config.use_cache=True`) — A Cache instance; for more details, see our kv cache guide. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see the `past_key_values` input) to speed up sequential decoding.
- hidden_states (`tuple(torch.FloatTensor)`, optional, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) — Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, plus one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, optional, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) — Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The BltForCausalLM forward method overrides the `__call__` special method. Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoTokenizer, BltForCausalLM
>>> model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
>>> prompt = "If I had to write a haiku, it would be:"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
>>> print(result)
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
I love the idea of snowflakes gently falling, each one