from typing import Optional, Tuple import jax import jax.numpy as jnp from jax.random import PRNGKey import flax.linen as nn from flax.core.frozen_dict import FrozenDict, unfreeze from transformers.modeling_flax_outputs import FlaxCausalLMOutputWithCrossAttentions from transformers.file_utils import add_start_docstrings from transformers.modeling_flax_utils import FlaxPreTrainedModel from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGenerationModule from t5_vae_flax_alt.src.vae import VAE from t5_vae_flax_alt.src.generate import VaeFlaxGenerationMixin from t5_vae_flax_alt.src.outputs import TransformerVaeOutput from t5_vae_flax_alt.src.config import T5VaeConfig @add_start_docstrings("""T5 Model with a `language modeling` head on top converted into a VAE.""") class FlaxT5VaeForAutoencodingModule(nn.Module): config: T5VaeConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def _get_encoder_module(self): return self.t5.encoder def _get_vae_encoder_module(self): return self.vae.encoder def _get_vae_decoder_module(self): return self.vae.decoder def _get_decoder_module(self): return self.t5.decoder def setup(self): self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5) self.vae = VAE(self.config) def __call__( self, input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, latent_codes=None, output_attentions=None, output_hidden_states=None, return_dict=None, deterministic: bool = True, ): """ Adapted from `FlaxT5ForConditionalGenerationModule` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode encoder_outputs = self.t5.encoder( input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) hidden_states = encoder_outputs[0] # Autoencode hidden_states, latent_codes = self.vae(hidden_states, latent_codes) encoder_attention_mask = jnp.ones((hidden_states.shape[0], hidden_states.shape[1])) # Decode decoder_outputs = self.t5.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_hidden_states=hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.config.t5.d_model ** -0.5) if self.t5.config.tie_word_embeddings: shared_embedding = self.t5.shared.variables["params"]["embedding"] lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) else: lm_logits = self.t5.lm_head(sequence_output) if not return_dict: return [lm_logits, latent_codes] + decoder_outputs[1:] + encoder_outputs return TransformerVaeOutput( logits=lm_logits, latent_codes=latent_codes, last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) class FlaxT5VaePreTrainedModel(FlaxPreTrainedModel, VaeFlaxGenerationMixin): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = T5VaeConfig base_model_prefix = "transformer" module_class: nn.Module = None def __init__( self, config: T5VaeConfig, input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) decoder_input_ids = jnp.ones_like(input_ids) decoder_attention_mask = jnp.ones_like(input_ids) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} return self.module.init( rngs, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, )["params"] def __call__( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, decoder_input_ids: jnp.ndarray = None, decoder_attention_mask: Optional[jnp.ndarray] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict if decoder_input_ids is None: raise ValueError( "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here." ) # prepare encoder inputs if attention_mask is None: attention_mask = jnp.ones_like(input_ids) # prepare decoder inputs if decoder_attention_mask is None: decoder_attention_mask = jnp.ones_like(decoder_input_ids) # Handle any PRNG if needed rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} return self.module.apply( {"params": params or self.params}, input_ids=jnp.array(input_ids, dtype="i4"), attention_mask=jnp.array(attention_mask, dtype="i4"), decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, ) def init_cache(self, batch_size, max_length, latent_codes): r""" Args: batch_size (:obj:`int`): batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. max_length (:obj:`int`): maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized cache. latent_codes (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): ``latent_codes`` consists of compressed hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. """ # init input variables to retrieve cache decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") decoder_attention_mask = jnp.ones_like(decoder_input_ids) def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): decoder_module = module._get_decoder_module() return decoder_module( decoder_input_ids, decoder_attention_mask, **kwargs, ) init_variables = self.module.init( jax.random.PRNGKey(0), decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, init_cache=True, method=_decoder_forward, # we only need to call the decoder to init the cache ) return unfreeze(init_variables["cache"]) def encode( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): raise NotImplementedError() def decode( self, decoder_input_ids, latent_codes, encoder_attention_mask: Optional[jnp.ndarray] = None, decoder_attention_mask: Optional[jnp.ndarray] = None, past_key_values: dict = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): raise NotImplementedError() class FlaxT5VaeForAutoencoding(FlaxT5VaePreTrainedModel): module_class = FlaxT5VaeForAutoencodingModule def __call__( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, decoder_input_ids=None, decoder_attention_mask=None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): ''' Adapted from `FlaxT5PreTrainedModel` ''' output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict if decoder_input_ids is None: raise ValueError( "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here." ) # prepare encoder inputs if attention_mask is None: attention_mask = jnp.ones_like(input_ids) # prepare decoder inputs if decoder_attention_mask is None: decoder_attention_mask = jnp.ones_like(decoder_input_ids) # Handle any PRNG if needed rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} return self.module.apply( {"params": params or self.params}, input_ids=jnp.array(input_ids, dtype="i4"), attention_mask=jnp.array(attention_mask, dtype="i4"), decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, ) def encode( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict if attention_mask is None: attention_mask = jnp.ones_like(input_ids) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng def _encoder_forward(module, input_ids, attention_mask, **kwargs): encode_module = module._get_encoder_module() vae_encoder_module = module._get_vae_encoder_module() return vae_encoder_module(encode_module(input_ids, attention_mask, **kwargs)[0]) return self.module.apply( {"params": params or self.params}, input_ids=jnp.array(input_ids, dtype="i4"), attention_mask=jnp.array(attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, method=_encoder_forward, ) def decode( self, decoder_input_ids, latent_codes, encoder_attention_mask: Optional[jnp.ndarray] = None, decoder_attention_mask: Optional[jnp.ndarray] = None, past_key_values: dict = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): r""" Returns: Example:: >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-small') >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') >>> text = "My friends are cool but they eat too many carbs." >>> inputs = tokenizer(text, max_length=512, return_tensors='jax') >>> latent_codes = model.encode(**inputs) >>> decoder_start_token_id = model.config.decoder_start_token_id >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id >>> outputs = model.decode(decoder_input_ids, latent_codes) >>> last_decoder_hidden_states = outputs.last_hidden_state """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict if encoder_attention_mask is None: batch_size, sequence_length = latent_codes.shape[:2] encoder_attention_mask = jnp.ones((batch_size, sequence_length)) batch_size, sequence_length = decoder_input_ids.shape if decoder_attention_mask is None: decoder_attention_mask = jnp.ones((batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} # if past_key_values are passed then cache is already initialized a private flag init_cache has to be # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that # it can be changed by FlaxT5Attention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs): vae_decoder_module = module._get_vae_decoder_module() decoder_module = module._get_decoder_module() decoder_outputs = decoder_module( decoder_input_ids, decoder_attention_mask, encoder_hidden_states=vae_decoder_module(latent_codes), **kwargs, ) sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.config.t5.d_model ** -0.5) if self.config.tie_word_embeddings: shared_embedding = module.t5.shared.variables["params"]["embedding"] lm_logits = module.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) else: lm_logits = module.t5.lm_head(sequence_output) return lm_logits, decoder_outputs outputs = self.module.apply( inputs, decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), latent_codes=latent_codes, decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, mutable=mutable, method=_decoder_forward, ) if past_key_values is None: lm_logits, decoder_outputs = outputs else: (lm_logits, decoder_outputs), past = outputs if return_dict: outputs = FlaxCausalLMOutputWithCrossAttentions( logits=lm_logits, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, ) else: outputs = (lm_logits,) + decoder_outputs[1:] # add updated cache to model output if past_key_values is not None and return_dict: outputs["past_key_values"] = unfreeze(past["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] return outputs def prepare_inputs_for_generation( self, decoder_input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None, decoder_attention_mask: Optional[jnp.DeviceArray] = None, latent_codes=None, **kwargs ): # initializing the cache batch_size, seq_length = decoder_input_ids.shape past_key_values = self.init_cache(batch_size, max_length, latent_codes) # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. # But since the decoder uses a causal mask, those positions are masked anyways. # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if decoder_attention_mask is not None: extended_attention_mask = jax.lax.dynamic_update_slice( extended_attention_mask, decoder_attention_mask, (0, 0) ) return { "past_key_values": past_key_values, "latent_codes": latent_codes, "encoder_attention_mask": attention_mask, "decoder_attention_mask": extended_attention_mask, } def update_inputs_for_generation(self, model_outputs, model_kwargs): model_kwargs["past_key_values"] = model_outputs.past_key_values return model_kwargs