# Installation

In [None]:
# install transformers and flax directly from their GitHub repositories (default branch, unpinned)
!pip install git+https://github.com/huggingface/transformers/
!pip install git+https://github.com/google/flax

In [None]:
# IPython autoreload: reload all imported modules before executing each cell
%load_ext autoreload
%autoreload 2

# Custom BART Model

In [None]:
# TODO: set those args in a config file
# Number of ids in the encoded image token space (image tokens are 0 .. IMAGE_VOCAB_SIZE - 1).
IMAGE_VOCAB_SIZE = 16384
# The decoder BOS token uses the first id past the image tokens, so it can never
# collide with a real image token; deriving it keeps it in sync with the vocab size.
BOS_TOKEN_ID = IMAGE_VOCAB_SIZE
OUTPUT_VOCAB_SIZE = IMAGE_VOCAB_SIZE + 1  # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
BASE_MODEL = 'facebook/bart-large'

In [None]:
import jax
import flax.linen as nn

from transformers.models.bart.modeling_flax_bart import *
from transformers import BartTokenizer, FlaxBartForConditionalGeneration

class CustomFlaxBartModule(FlaxBartModule):
    """BART encoder-decoder whose decoder uses a separate output vocabulary.

    The encoder and the ``shared`` embedding are built from the base config,
    so pretrained BART weights can be loaded into them. The decoder gets:

    * its own embedding table of ``OUTPUT_VOCAB_SIZE`` entries;
    * a copy of the base config with ``vocab_size`` overridden to
      ``OUTPUT_VOCAB_SIZE`` and ``max_position_embeddings`` to ``OUTPUT_LENGTH``.
    """

    def setup(self):
        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            OUTPUT_VOCAB_SIZE,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # The decoder has a different config. Copy the base config with `from_dict`.
        # `BartConfig(d)` would bind the whole dict to the first positional
        # parameter (`vocab_size`) and silently use BartConfig defaults for every
        # other field, rather than the base model's settings.
        decoder_config = BartConfig.from_dict(self.config.to_dict())
        decoder_config.max_position_embeddings = OUTPUT_LENGTH
        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)

class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Conditional-generation module built on ``CustomFlaxBartModule``.

    The LM head and ``final_logits_bias`` are sized to ``OUTPUT_VOCAB_SIZE``
    instead of the base config's vocabulary size.
    """

    def setup(self):
        init_std = self.config.init_std
        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
        # bias-free projection onto the output vocabulary
        self.lm_head = nn.Dense(
            OUTPUT_VOCAB_SIZE,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(init_std, self.dtype),
        )
        # separate logits bias parameter of shape (1, OUTPUT_VOCAB_SIZE)
        self.final_logits_bias = self.param(
            "final_logits_bias",
            self.bias_init,
            (1, OUTPUT_VOCAB_SIZE),
        )

class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """FlaxBartForConditionalGeneration whose ``module_class`` is ``CustomFlaxBartForConditionalGenerationModule``."""
    module_class = CustomFlaxBartForConditionalGenerationModule

In [None]:
# load pre-trained model for encoder weights
# (used only as a source of parameters; it is deleted once they are copied)
base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)

In [None]:
# set up our new model config
from transformers import BartConfig  # explicit, instead of relying on the star import

config = BartConfig.from_pretrained(BASE_MODEL)
# our lm_head is sized to OUTPUT_VOCAB_SIZE while `shared` keeps BART's vocabulary,
# so the output projection must not be tied to the input embeddings
config.tie_word_embeddings = False
config.decoder_start_token_id = BOS_TOKEN_ID
config.bos_token_id = BOS_TOKEN_ID  # should not be used
# This line previously set `config.pos_token_id`, which is not a BartConfig field
# and had no effect. It was presumably meant to be the padding token.
config.pad_token_id = BOS_TOKEN_ID  # should not be used
# NOTE(review): eos_token_id is still inherited from BASE_MODEL, and that id is also
# a valid image-token id; generation may stop early if it is produced — confirm.
#config.eos_token_id = None # prevents generation from stopping until we reach max_length

In [None]:
# create our model and initialize it randomly
# (the pretrained encoder/shared params are copied in afterwards; the decoder stays random)
model = CustomFlaxBartForConditionalGeneration(config)

In [None]:
# use pretrained weights: the encoder and the `shared` token embedding it reads from
# are taken from the base model; the decoder and lm_head keep their random init
model.params['model'].update(
    encoder=base_model.params['model']['encoder'],
    shared=base_model.params['model']['shared'],
)

In [None]:
# no need for base_model anymore
# (dropping the reference lets its unused params be freed; the copied
# encoder/shared arrays remain referenced from model.params)
del base_model

In [None]:
# we verify that the shape has not been modified: the logits bias must cover the
# output vocabulary (OUTPUT_VOCAB_SIZE), not BART's text vocabulary
assert model.params['final_logits_bias'].shape == (1, OUTPUT_VOCAB_SIZE), model.params['final_logits_bias'].shape
model.params['final_logits_bias'].shape

## Inference

In [None]:
# the encoder keeps BASE_MODEL's embeddings, so BASE_MODEL's tokenizer is reused for inputs
tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)

In [None]:
text = "My friends are cool but they eat too many carbs."
# truncation=True makes the cut at max_length explicit; passing max_length alone
# makes the tokenizer warn and pick a truncation strategy implicitly
inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors='jax')
encoder_outputs = model.encode(**inputs)

In [None]:
# the id the decoder starts from, read back from the model config (configured as BOS_TOKEN_ID)
decoder_start_token_id = model.config.decoder_start_token_id
decoder_start_token_id

In [None]:
import jax.numpy as jnp  # explicit; previously only in scope through the star import

# one decoder_start_token_id per batch row, shape (batch_size, 1), int32
decoder_input_ids = jnp.full((inputs.input_ids.shape[0], 1), decoder_start_token_id, dtype="i4")
outputs = model.decode(decoder_input_ids, encoder_outputs)

In [None]:
# display the decoder output for the single start-token step
outputs

In [None]:
# expected (batch_size, 1, OUTPUT_VOCAB_SIZE) since lm_head is sized to OUTPUT_VOCAB_SIZE — confirm
outputs.logits.shape

In [None]:
# highest-scoring output token id per position (decoder weights are still random here)
outputs.logits.argmax(axis=-1)

In [None]:
# inspect the special token ids stored in the model config
model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id

In [None]:
# a second prompt, tokenized with the base tokenizer, to try generation
input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')

In [None]:
# NOTE(review): generate() falls back to decoding settings in model.config, which was
# loaded from BASE_MODEL; confirm they really give greedy decoding as the name implies.
# max_length=50 is a short smoke test, well below OUTPUT_LENGTH.
greedy_output = model.generate(input_ids_test, max_length=50)

In [None]:
# first element of the generate() output
greedy_output[0]