# Installation

In [1]:
!pip install git+https://github.com/huggingface/transformers/
!pip install git+https://github.com/google/flax

Collecting git+https://github.com/huggingface/transformers/
 Cloning https://github.com/huggingface/transformers/ to /tmp/pip-req-build-oxejx1op
 Running command git clone -q https://github.com/huggingface/transformers/ /tmp/pip-req-build-oxejx1op
 Installing build dependencies ... [?25l[?25hdone
 Getting requirements to build wheel ... [?25l[?25hdone
 Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: transformers
 Building wheel for transformers (PEP 517) ... [?25l[?25hdone
 Created wheel for transformers: filename=transformers-4.9.0.dev0-cp37-none-any.whl size=2582229 sha256=249c593273ccca3027c6427d2c6fd749a89f21d722d628d97eb438a2cf3185a8
 Stored in directory: /tmp/pip-ephem-wheel-cache-l2rqt1b7/wheels/61/69/33/974fccec4d0ab5feee9fe83bd93e680d269a805be9ede5ec60
Successfully built transformers
Collecting git+https://github.com/google/flax
 Cloning https://github.com/google/flax to /tmp/pip-req-build-rt9g1_wx
 Running command git clone -q https

In [2]:
%load_ext autoreload
%autoreload 2

# Custom BART Model

In [3]:
# TODO: set those args in a config file
OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
BOS_TOKEN_ID = 16384
BASE_MODEL = 'facebook/bart-large-cnn'

In [4]:
import jax
import flax.linen as nn

from transformers.models.bart.modeling_flax_bart import *
from transformers import BartTokenizer, FlaxBartForConditionalGeneration

class CustomFlaxBartModule(FlaxBartModule):
 def setup(self):
 # we keep shared to easily load pre-trained weights
 self.shared = nn.Embed(
 self.config.vocab_size,
 self.config.d_model,
 embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
 dtype=self.dtype,
 )
 # a separate embedding is used for the decoder
 self.decoder_embed = nn.Embed(
 OUTPUT_VOCAB_SIZE,
 self.config.d_model,
 embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
 dtype=self.dtype,
 )
 self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

 # the decoder has a different config
 decoder_config = BartConfig(self.config.to_dict())
 decoder_config.max_position_embeddings = OUTPUT_LENGTH
 decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
 self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)

class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
 def setup(self):
 self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
 self.lm_head = nn.Dense(
 OUTPUT_VOCAB_SIZE,
 use_bias=False,
 dtype=self.dtype,
 kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
 )
 self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))

class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
 module_class = CustomFlaxBartForConditionalGenerationModule

In [5]:
# load pre-trained model for encoder weights
base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)



In [6]:
# set up our new model config
config = BartConfig.from_pretrained(BASE_MODEL)
config.tie_word_embeddings = False
config.decoder_start_token_id = BOS_TOKEN_ID
config.bos_token_id = BOS_TOKEN_ID # should not be used
config.pos_token_id = BOS_TOKEN_ID # should not be used
#config.eos_token_id = None # prevents generation from stopping until we reach max_length

In [7]:
# create our model and initialize it randomly
model = CustomFlaxBartForConditionalGeneration(config)

In [8]:
# use pretrained weights
model.params['model']['encoder'] = base_model.params['model']['encoder']
model.params['model']['shared'] = base_model.params['model']['shared']

In [9]:
# no need for base_model anymore
del base_model

In [10]:
# we verify that the shape has not been modified
model.params['final_logits_bias'].shape

(1, 16385)

## Inference

In [11]:
tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)

In [12]:
text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors='jax')
encoder_outputs = model.encode(**inputs)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [13]:
decoder_start_token_id = model.config.decoder_start_token_id
decoder_start_token_id

16384

In [14]:
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
outputs = model.decode(decoder_input_ids, encoder_outputs)

In [15]:
outputs

FlaxCausalLMOutputWithCrossAttentions([('logits',
 DeviceArray([[[ 0.5263986 , -2.0947676 , -0.18830685, ..., 0.7599884 ,
 0.6746795 , -1.0411576 ]]], dtype=float32))])

In [16]:
outputs.logits.shape

(1, 1, 16385)

In [17]:
outputs.logits.argmax(axis=-1)

DeviceArray([[12459]], dtype=int32)

In [18]:
model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id

(16384, 2, 1)

In [19]:
input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')

In [20]:
greedy_output = model.generate(input_ids_test, max_length=50)

In [21]:
greedy_output[0]

DeviceArray([[16384, 0, 3570, 13405, 10186, 2392, 16362, 1869,
 15772, 13546, 15772, 13546, 9348, 14791, 15772, 15772,
 15772, 11272, 15772, 13546, 15772, 15772, 13546, 15772,
 13546, 15772, 6642, 15772, 10776, 6431, 15772, 14567,
 13406, 15772, 14567, 6235, 15772, 4909, 16160, 568,
 4664, 6650, 8952, 9089, 15772, 5952, 7375, 10843,
 8952, 2]], dtype=int32)