DBRX
Overview
DBRX is a transformer-based decoder-only large language model (LLM) that was trained using next-token prediction. It uses a fine-grained mixture-of-experts (MoE) architecture with 132B total parameters, of which 36B are active on any input. It was pretrained on 12T tokens of text and code data. Compared to other open MoE models like Mixtral-8x7B and Grok-1, DBRX is fine-grained, meaning it uses a larger number of smaller experts. DBRX has 16 experts and chooses 4, while Mixtral-8x7B and Grok-1 have 8 experts and choose 2. This provides 65x more possible combinations of experts (1,820 versus 28), and we found that this improves model quality. DBRX uses rotary position encodings (RoPE), gated linear units (GLU), and grouped query attention (GQA). It is a BPE-based model and uses the GPT-4 tokenizer as described in the tiktoken repository. We made these choices based on exhaustive evaluation and scaling experiments.
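The 65x figure comes from counting unordered subsets of experts. A quick check, using only the expert counts quoted above (the variable names are illustrative):

```python
from math import comb

# Number of distinct expert subsets a single token can be routed to.
dbrx_combinations = comb(16, 4)     # 16 experts, choose 4
mixtral_combinations = comb(8, 2)   # 8 experts, choose 2

ratio = dbrx_combinations // mixtral_combinations
print(dbrx_combinations, mixtral_combinations, ratio)  # 1820 28 65
```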
DBRX was pretrained on 12T tokens of carefully curated data with a maximum context length of 32K tokens. We estimate that this data is at least 2x better token-for-token than the data we used to pretrain the MPT family of models. This new dataset was developed using the full suite of Databricks tools, including Apache Spark™ and Databricks notebooks for data processing, and Unity Catalog for data management and governance. We used curriculum learning for pretraining, changing the data mix during training in ways we found to substantially improve model quality.
More detailed information about DBRX Instruct and DBRX Base can be found in our technical blog post.
This model was contributed by eitan-turok and abhi-db. The original code can be found here, though this may not be up to date.
Usage Examples
The generate() method can be used to generate text using DBRX. You can generate using the standard attention implementation, flash-attention, or the PyTorch scaled dot product attention. The last two attention implementations give speedups.
from transformers import DbrxForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
"databricks/dbrx-instruct",
device_map="auto",
torch_dtype=torch.bfloat16,
token="YOUR_HF_TOKEN",
)
input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
If you have flash-attention installed (pip install flash-attn), it is possible to generate faster. (The HuggingFace documentation for flash-attention can be found here.)
from transformers import DbrxForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
"databricks/dbrx-instruct",
device_map="auto",
torch_dtype=torch.bfloat16,
token="YOUR_HF_TOKEN",
attn_implementation="flash_attention_2",
)
input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
You can also generate faster using the PyTorch scaled dot product attention. (The HuggingFace documentation for scaled dot product attention can be found here.)
from transformers import DbrxForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
"databricks/dbrx-instruct",
device_map="auto",
torch_dtype=torch.bfloat16,
token="YOUR_HF_TOKEN",
attn_implementation="sdpa",
)
input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
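The three snippets above differ only in the attn_implementation argument. As a minimal sketch, the hypothetical helper below (pick_attn_implementation is not part of transformers) picks a value based on which packages are installed. It assumes "eager" names the standard implementation and does not check GPU support, which flash-attention also requires.

```python
import importlib.util


def pick_attn_implementation() -> str:
    """Hypothetical helper: choose a value for `attn_implementation`."""
    # flash-attn is an optional package (pip install flash-attn).
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    # Scaled dot product attention is only available in newer PyTorch releases.
    if importlib.util.find_spec("torch") is not None:
        import torch

        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            return "sdpa"
    # Fall back to the standard attention implementation.
    return "eager"


attn_implementation = pick_attn_implementation()
print(attn_implementation)
```

The returned string can then be passed as attn_implementation=... to from_pretrained, as in the examples above.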
DbrxConfig
class transformers.DbrxConfig
( d_model: int = 2048 n_heads: int = 16 n_layers: int = 24 max_seq_len: int = 2048 vocab_size: int = 32000 resid_pdrop: float = 0.0 emb_pdrop: float = 0.0 attn_config: Optional = None ffn_config: Optional = None use_cache: bool = True initializer_range: float = 0.02 output_router_logits: bool = False **kwargs: Any )
Parameters
- d_model (int, optional, defaults to 2048) — Dimensionality of the embeddings and hidden states.
- n_heads (int, optional, defaults to 16) — Number of attention heads for each attention layer in the Transformer decoder.
- n_layers (int, optional, defaults to 24) — Number of hidden layers in the Transformer decoder.
- max_seq_len (int, optional, defaults to 2048) — The maximum sequence length of the model.
- vocab_size (int, optional, defaults to 32000) — Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by the input_ids passed when calling DbrxModel.
- resid_pdrop (float, optional, defaults to 0.0) — The dropout probability applied to the attention output before combining with residual.
- emb_pdrop (float, optional, defaults to 0.0) — The dropout probability for the embedding layer.
- attn_config (dict, optional) — A dictionary used to configure the model’s attention module.
- ffn_config (dict, optional) — A dictionary used to configure the model’s FFN module.
- use_cache (bool, optional, defaults to True) — Whether or not the model should return the last key/values attentions (not used by all models).
- initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- output_router_logits (bool, optional, defaults to False) — Whether or not the router logits should be returned by the model. Enabling this will also allow the model to output the auxiliary loss. See here for more details.
This is the configuration class to store the configuration of a DbrxModel. It is used to instantiate a Dbrx model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a different configuration to that of the databricks/dbrx-instruct architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
Example:
>>> from transformers import DbrxConfig, DbrxModel
>>> # Initializing a Dbrx configuration
>>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)
>>> # Initializing a model (with random weights) from the configuration
>>> model = DbrxModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
DbrxModel
class transformers.DbrxModel
( config: DbrxConfig )
Parameters
- config (DbrxConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The bare DBRX Model outputting raw hidden-states without any specific head on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a DbrxBlock.
forward
( input_ids: Optional = None attention_mask: Optional = None position_ids: Optional = None past_key_values: Optional = None inputs_embeds: Optional = None use_cache: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None output_router_logits: Optional = None return_dict: Optional = None cache_position: Optional = None )
Parameters
- input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
  - 1 for tokens that are not masked,
  - 0 for tokens that are masked.
  Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
  If past_key_values is used, optionally only the last input_ids have to be input (see past_key_values).
  If you want to change padding behavior, you should read modeling_opt._prepare_decoder_attention_mask and modify to your needs. See diagram 1 in the paper for more information on the default strategy.
- position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_seq_len - 1].
- past_key_values (Cache or tuple(tuple(torch.FloatTensor)), optional) — Pre-computed hidden-states (keys and values in the self-attention blocks) that can be used to speed up sequential decoding. This typically consists of the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True.
  Two formats are allowed:
  - a Cache instance;
  - a tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head). This is also known as the legacy cache format.
  The model will output the same cache format that is fed as input. If no past_key_values are passed, the legacy cache format will be returned.
  If past_key_values are used, the user can optionally input only the last input_ids (those that don’t have their past key value states given to this model) of shape (batch_size, 1) instead of all input_ids of shape (batch_size, sequence_length).
- inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
- use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
- output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- output_router_logits (bool, optional) — Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Contrary to position_ids, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.
The DbrxModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
DbrxForCausalLM
class transformers.DbrxForCausalLM
( config: DbrxConfig )
Parameters
- config (DbrxConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The DBRX Model transformer for causal language modeling. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
( input_ids: Optional = None attention_mask: Optional = None position_ids: Optional = None past_key_values: Optional = None inputs_embeds: Optional = None labels: Optional = None use_cache: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None output_router_logits: Optional = None return_dict: Optional = None cache_position: Optional = None ) → transformers.modeling_outputs.MoeCausalLMOutputWithPast or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
Returns
transformers.modeling_outputs.MoeCausalLMOutputWithPast or tuple(torch.FloatTensor)
A transformers.modeling_outputs.MoeCausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (DbrxConfig) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss (for next-token prediction).
- logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- aux_loss (torch.FloatTensor, optional, returned when labels is provided) — aux_loss for the sparse modules.
- router_logits (tuple(torch.FloatTensor), optional, returned when output_router_logits=True is passed or when config.output_router_logits=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, sequence_length, num_experts).
  Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary loss for Mixture of Experts models.
- past_key_values (tuple(tuple(torch.FloatTensor)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head).
  Contains pre-computed hidden-states (keys and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
  Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
  Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
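To make the router outputs concrete, here is a minimal pure-Python sketch of top-4-of-16 routing for one token and a Switch-Transformer-style load-balancing term over a batch of tokens. It assumes the router emits unnormalized scores per expert. The function names are illustrative, and this is not necessarily the exact computation transformers uses for aux_loss.

```python
import math
import random

NUM_EXPERTS, TOP_K = 16, 4  # expert counts quoted in the overview


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def route(logits, top_k=TOP_K):
    """Return (router probabilities, indices of the top_k experts) for one token."""
    probs = softmax(logits)
    chosen = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:top_k]
    return probs, chosen


def load_balancing_loss(all_logits, top_k=TOP_K):
    """Balance term E * sum_e f_e * p_e; equals 1.0 when the router is uniform."""
    n_tokens, n_experts = len(all_logits), len(all_logits[0])
    frac_routed = [0.0] * n_experts  # f_e: share of routing slots given to expert e
    mean_prob = [0.0] * n_experts    # p_e: mean router probability of expert e
    for logits in all_logits:
        probs, chosen = route(logits, top_k)
        for e in chosen:
            frac_routed[e] += 1.0 / (n_tokens * top_k)
        for e, p in enumerate(probs):
            mean_prob[e] += p / n_tokens
    return n_experts * sum(f * p for f, p in zip(frac_routed, mean_prob))


random.seed(0)
router_logits = [[random.gauss(0.0, 1.0) for _ in range(NUM_EXPERTS)] for _ in range(64)]
probs, chosen = route(router_logits[0])
print(chosen, load_balancing_loss(router_logits))
```

The term grows when a few experts receive both most of the routing slots and most of the probability mass, which is why it is used to encourage balanced expert usage.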
The DbrxForCausalLM forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Forward function for causal language modeling.
Example:
>>> from transformers import AutoTokenizer, DbrxForCausalLM
>>> model = DbrxForCausalLM.from_pretrained("databricks/dbrx-instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."