|
|
|
import jax |
|
import flax.linen as nn |
|
|
|
from transformers.models.bart.modeling_flax_bart import ( |
|
FlaxBartModule, |
|
FlaxBartForConditionalGenerationModule, |
|
FlaxBartForConditionalGeneration, |
|
FlaxBartEncoder, |
|
FlaxBartDecoder |
|
) |
|
|
|
from transformers import BartConfig |
|
|
|
|
|
|
|
# Output (decoder) token space: 16384 regular ids plus one extra id used as BOS,
# so BOS is the last id of the output vocabulary.
BOS_TOKEN_ID = 16384
OUTPUT_VOCAB_SIZE = BOS_TOKEN_ID + 1

# Decoder sequence length; presumably 256 target tokens plus the leading BOS
# (TODO confirm against the data pipeline).
OUTPUT_LENGTH = 256 + 1

# Pretrained checkpoint id; not referenced in this chunk -- presumably used by
# callers when loading the base weights.
BASE_MODEL = "facebook/bart-large-cnn"
|
|
|
|
|
class CustomFlaxBartModule(FlaxBartModule):
    """BART encoder-decoder whose decoder has its own vocabulary and length.

    The encoder embeds inputs with ``shared`` (``config.vocab_size`` rows).
    The decoder embeds with a separate ``decoder_embed`` table
    (``config.vocab_size_output`` rows). It runs with a copy of the model
    config in which ``vocab_size`` and ``max_position_embeddings`` are
    replaced by ``vocab_size_output`` and ``max_position_embeddings_decoder``.

    Side effect: ``setup`` writes ``vocab_size_output`` and
    ``max_position_embeddings_decoder`` onto ``self.config``. Missing values
    default to ``OUTPUT_VOCAB_SIZE`` and ``OUTPUT_LENGTH``.
    """

    def setup(self):
        # Fill in decoder-specific settings when the config lacks them.
        # This mutates the config object held by this module.
        self.config.vocab_size_output = getattr(
            self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
        )
        self.config.max_position_embeddings_decoder = getattr(
            self.config, "max_position_embeddings_decoder", OUTPUT_LENGTH
        )

        # Encoder-side embedding, kept under the attribute name ``shared``;
        # presumably so pretrained BART weights load into it (confirm).
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )

        # Separate embedding table for the decoder's output vocabulary.
        self.decoder_embed = nn.Embed(
            self.config.vocab_size_output,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )

        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # The decoder gets a full copy of the model config with only its two
        # decoder-specific fields overridden. BartConfig takes its settings as
        # keyword arguments. Passing the dict positionally would bind the whole
        # dict to ``vocab_size`` and silently reset every other field
        # (d_model, layers, heads, dropout, ...) to BartConfig defaults, which
        # desynchronizes the decoder from the encoder.
        decoder_config = BartConfig.from_dict(self.config.to_dict())
        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
        decoder_config.vocab_size = self.config.vocab_size_output
        self.decoder = FlaxBartDecoder(
            decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
        )
|
|
|
class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Seq2seq LM head over ``CustomFlaxBartModule``.

    The output projection ``lm_head`` and the additive ``final_logits_bias``
    are both sized by ``config.vocab_size_output`` (the decoder vocabulary),
    not by the encoder's ``config.vocab_size``.

    NOTE(review): suppose the inherited ``__call__`` ties the LM head to
    ``model.shared`` when ``config.tie_word_embeddings`` is true. The logits
    would then have ``config.vocab_size`` columns and would not match
    ``final_logits_bias``. Presumably callers set
    ``tie_word_embeddings=False``; confirm.
    """

    def setup(self):
        cfg = self.config

        # Default the output vocabulary size when the config does not carry one
        # (mutates the config object held by this module).
        cfg.vocab_size_output = getattr(cfg, "vocab_size_output", OUTPUT_VOCAB_SIZE)

        self.model = CustomFlaxBartModule(config=cfg, dtype=self.dtype)

        n_out = cfg.vocab_size_output
        kernel_initializer = jax.nn.initializers.normal(cfg.init_std, self.dtype)
        self.lm_head = nn.Dense(
            n_out,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=kernel_initializer,
        )
        # One bias entry per output token, broadcast over batch/sequence.
        self.final_logits_bias = self.param(
            "final_logits_bias", self.bias_init, (1, n_out)
        )
|
|
|
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """Pretrained-model wrapper that builds the custom seq2seq module.

    Only ``module_class`` is overridden. All loading, initialization and
    generation logic is inherited from ``FlaxBartForConditionalGeneration``.
    """

    module_class = CustomFlaxBartForConditionalGenerationModule
|
|