library_name: keras-hub
Model Overview
BART encoder-decoder network.
This class implements a Transformer-based encoder-decoder model as described in "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension".
The default constructor gives a fully customizable, randomly initialized BART model with any number of layers, heads, and embedding dimensions. To load preset architectures and weights, use the from_preset constructor.
Disclaimer: Pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind. The underlying model is provided by a third party and subject to a separate license, available here.
Arguments
- vocabulary_size: int. The size of the token vocabulary.
- num_layers: int. The number of transformer encoder layers and transformer decoder layers.
- num_heads: int. The number of attention heads for each transformer. The hidden size must be divisible by the number of attention heads.
- hidden_dim: int. The size of the transformer encoding and pooler layers.
- intermediate_dim: int. The output dimension of the first Dense layer in a two-layer feedforward network for each transformer.
- dropout: float. Dropout probability for the Transformer encoder.
- max_sequence_length: int. The maximum sequence length that this encoder can consume. If None, max_sequence_length uses the value from sequence length. This determines the variable shape for positional embeddings.
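The divisibility requirement on num_heads and the role of max_sequence_length can be checked with plain arithmetic before a model is built. The sketch below is illustrative only: the helper name check_bart_config is hypothetical and not part of keras-hub, the sizes are placeholder values rather than those of any preset, and the positional embedding shape assumes one learned vector of size hidden_dim per position.

```python
def check_bart_config(hidden_dim, num_heads, max_sequence_length):
    """Hypothetical helper (not a keras-hub API): validate sizes, derive shapes."""
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) must be divisible by "
            f"num_heads ({num_heads})"
        )
    return {
        # Each attention head works on an equal slice of the hidden size.
        "head_dim": hidden_dim // num_heads,
        # Assumption: one learned vector of size hidden_dim per position.
        "position_embedding_shape": (max_sequence_length, hidden_dim),
    }


# Placeholder sizes, chosen only for illustration.
config = check_bart_config(hidden_dim=768, num_heads=12, max_sequence_length=1024)
print(config["head_dim"])
print(config["position_embedding_shape"])
```

A hidden_dim that is not a multiple of num_heads raises a ValueError in this sketch, which mirrors the constraint stated for num_heads above.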
Example Usage
import keras
import keras_hub
import numpy as np
Use generate() to do text generation, given an input context.
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_large_en")
bart_lm.generate("The quick brown fox", max_length=30)
# Generate with batched inputs.
bart_lm.generate(["The quick brown fox", "The whale"], max_length=30)
Compile the generate() function with a custom sampler.
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_large_en")
bart_lm.compile(sampler="greedy")
bart_lm.generate("The quick brown fox", max_length=30)
Use generate() with encoder inputs and an incomplete decoder input (prompt).
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_large_en")
bart_lm.generate(
    {
        "encoder_text": "The quick brown fox",
        "decoder_text": "The fast"
    }
)
Use generate() without preprocessing.
# Preprocessed inputs, with encoder inputs corresponding to
# "The quick brown fox", and the decoder inputs to "The fast". Use
# `"decoder_padding_mask"` to indicate values that should not be overridden.
# Each padding mask must have the same shape as its token ID array.
prompt = {
    "encoder_token_ids": np.array([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
    "encoder_padding_mask": np.array(
        [[True, True, True, True, True, True, False, False]]
    ),
    "decoder_token_ids": np.array([[2, 0, 133, 1769, 2, 1, 1]]),
    "decoder_padding_mask": np.array(
        [[True, True, True, True, False, False, False]]
    ),
}
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_large_en",
    preprocessor=None,
)
bart_lm.generate(prompt)
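When preprocessing is bypassed, each token ID array and its padding mask need the same shape. A minimal sketch of building such a pair in plain Python follows. The helper pad_with_mask is hypothetical (not a keras-hub function), and the pad ID of 1 is inferred from the arrays in the example above.

```python
def pad_with_mask(token_ids, length, pad_id=1):
    """Hypothetical helper: right-pad token IDs and build the matching mask."""
    if len(token_ids) > length:
        raise ValueError("token_ids is longer than the target length")
    num_pad = length - len(token_ids)
    padded = list(token_ids) + [pad_id] * num_pad
    # True marks real tokens; False marks padding positions.
    mask = [True] * len(token_ids) + [False] * num_pad
    return padded, mask


# Encoder IDs for "The quick brown fox", taken from the example above.
ids, mask = pad_with_mask([0, 133, 2119, 6219, 23602, 2], length=8)
print(ids)
print(mask)
```

Wrapping each returned list in np.array([...]) adds the batch dimension used by the prompt dictionary.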
Call fit() on a single batch.
features = {
    "encoder_text": ["The quick brown fox jumped.", "I forgot my homework."],
    "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
}
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_large_en")
bart_lm.fit(x=features, batch_size=2)
Call fit() without preprocessing.
x = {
    "encoder_token_ids": np.array([[0, 133, 2119, 2, 1]] * 2),
    "encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2),
    "decoder_token_ids": np.array([[2, 0, 133, 1769, 2]] * 2),
    "decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
}
y = np.array([[0, 133, 1769, 2, 1]] * 2)
sw = np.array([[1, 1, 1, 1, 0]] * 2)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_large_en",
    preprocessor=None,
)
bart_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
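In the arrays above, y is the decoder sequence shifted left by one position, and sw gives zero weight to the padded target. The sketch below reproduces that relationship in plain Python from a single padded sequence. The function name shift_for_training is hypothetical, and it is an assumption that the library's preprocessor derives its training targets in the same way.

```python
def shift_for_training(token_ids, padding_mask):
    """Hypothetical helper: split one padded sequence into inputs and targets."""
    if len(token_ids) != len(padding_mask):
        raise ValueError("token_ids and padding_mask must have equal length")
    return {
        # The decoder sees every token except the last one.
        "decoder_token_ids": token_ids[:-1],
        "decoder_padding_mask": padding_mask[:-1],
        # The target at each step is the next token.
        "y": token_ids[1:],
        # Padded target positions get zero weight in the loss.
        "sample_weight": padding_mask[1:],
    }


# One decoder sequence with a trailing pad token (ID 1).
out = shift_for_training([2, 0, 133, 1769, 2, 1], [1, 1, 1, 1, 1, 0])
print(out["decoder_token_ids"])
print(out["y"])
print(out["sample_weight"])
```

The three printed lists match one row of decoder_token_ids, y, and sw in the fit() example above.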
Example Usage with Hugging Face URI
import keras
import keras_hub
import numpy as np
Use generate() to do text generation, given an input context.
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("hf://keras/bart_large_en")
bart_lm.generate("The quick brown fox", max_length=30)
# Generate with batched inputs.
bart_lm.generate(["The quick brown fox", "The whale"], max_length=30)
Compile the generate() function with a custom sampler.
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("hf://keras/bart_large_en")
bart_lm.compile(sampler="greedy")
bart_lm.generate("The quick brown fox", max_length=30)
Use generate() with encoder inputs and an incomplete decoder input (prompt).
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("hf://keras/bart_large_en")
bart_lm.generate(
    {
        "encoder_text": "The quick brown fox",
        "decoder_text": "The fast"
    }
)
Use generate() without preprocessing.
# Preprocessed inputs, with encoder inputs corresponding to
# "The quick brown fox", and the decoder inputs to "The fast". Use
# `"decoder_padding_mask"` to indicate values that should not be overridden.
# Each padding mask must have the same shape as its token ID array.
prompt = {
    "encoder_token_ids": np.array([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
    "encoder_padding_mask": np.array(
        [[True, True, True, True, True, True, False, False]]
    ),
    "decoder_token_ids": np.array([[2, 0, 133, 1769, 2, 1, 1]]),
    "decoder_padding_mask": np.array(
        [[True, True, True, True, False, False, False]]
    ),
}
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "hf://keras/bart_large_en",
    preprocessor=None,
)
bart_lm.generate(prompt)
Call fit() on a single batch.
features = {
    "encoder_text": ["The quick brown fox jumped.", "I forgot my homework."],
    "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
}
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("hf://keras/bart_large_en")
bart_lm.fit(x=features, batch_size=2)
Call fit() without preprocessing.
x = {
    "encoder_token_ids": np.array([[0, 133, 2119, 2, 1]] * 2),
    "encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2),
    "decoder_token_ids": np.array([[2, 0, 133, 1769, 2]] * 2),
    "decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
}
y = np.array([[0, 133, 1769, 2, 1]] * 2)
sw = np.array([[1, 1, 1, 1, 0]] * 2)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "hf://keras/bart_large_en",
    preprocessor=None,
)
bart_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)