import flax.linen as nn import jax from transformers import BartConfig from transformers.models.bart.modeling_flax_bart import ( FlaxBartDecoder, FlaxBartEncoder, FlaxBartForConditionalGeneration, FlaxBartForConditionalGenerationModule, FlaxBartModule, ) class CustomFlaxBartModule(FlaxBartModule): def setup(self): # we keep shared to easily load pre-trained weights self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) # a separate embedding is used for the decoder self.decoder_embed = nn.Embed( self.config.image_vocab_size + 1, self.config.d_model, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.encoder = FlaxBartEncoder( self.config, dtype=self.dtype, embed_tokens=self.shared ) # the decoder has a different config # TODO: should not be needed once we have custom config/module decoder_config = BartConfig(self.config.to_dict()) decoder_config.max_position_embeddings = ( self.config.image_length + 1 # image tokens + BOS ) decoder_config.vocab_size = self.config.image_vocab_size + 1 self.decoder = FlaxBartDecoder( decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed ) class CustomFlaxBartForConditionalGenerationModule( FlaxBartForConditionalGenerationModule ): def setup(self): # set default config self.config.normalize_text = getattr(self.config, "normalize_text", False) self.config.image_length = getattr(self.config, "image_length", 256) self.config.image_vocab_size = getattr(self.config, "image_vocab_size", 16384) self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype) self.lm_head = nn.Dense( self.config.image_vocab_size + 1, # encoded image token space + 1 for bos use_bias=False, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.final_logits_bias = self.param( "final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1) ) class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration): module_class = CustomFlaxBartForConditionalGenerationModule