from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from jax.random import PRNGKey

import flax.linen as nn
from flax.core.frozen_dict import FrozenDict, unfreeze

from transformers.modeling_flax_outputs import FlaxCausalLMOutputWithCrossAttentions
from transformers.file_utils import add_start_docstrings
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGenerationModule

from t5_vae_flax_alt.src.vae import VAE
from t5_vae_flax_alt.src.generate import VaeFlaxGenerationMixin
from t5_vae_flax_alt.src.outputs import TransformerVaeOutput
from t5_vae_flax_alt.src.config import T5VaeConfig
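
# Overview: this file wires a Flax T5 to a VAE bottleneck. The T5 encoder
# produces hidden states, the `VAE` module (from `t5_vae_flax_alt.src.vae`)
# maps them to latent codes and back, and the T5 decoder attends over the
# reconstructed states to produce token logits.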


class FlaxT5VaeForAutoencodingModule(nn.Module):
    config: T5VaeConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def _get_encoder_module(self):
        return self.t5.encoder

    def _get_vae_encoder_module(self):
        return self.vae.encoder

    def _get_vae_decoder_module(self):
        return self.vae.decoder

    def _get_decoder_module(self):
        return self.t5.decoder

    def setup(self):
        self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5)
        self.vae = VAE(self.config)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        latent_codes=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        deterministic: bool = True,
    ):
        """
        Adapted from `FlaxT5ForConditionalGenerationModule`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
        hidden_states = encoder_outputs[0]

        # Autoencode
        hidden_states, latent_codes = self.vae(hidden_states, latent_codes)
        encoder_attention_mask = jnp.ones((hidden_states.shape[0], hidden_states.shape[1]))

        # Decode
        decoder_outputs = self.t5.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.config.t5.d_model ** -0.5)

        if self.t5.config.tie_word_embeddings:
            shared_embedding = self.t5.shared.variables["params"]["embedding"]
            lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
        else:
            lm_logits = self.t5.lm_head(sequence_output)

        if not return_dict:
            return (lm_logits, latent_codes) + decoder_outputs[1:] + encoder_outputs

        return TransformerVaeOutput(
            logits=lm_logits,
            latent_codes=latent_codes,
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
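
# A rough sketch of the data flow through the module above, for input ids of
# shape (batch, seq_len). The latent layout depends on `T5VaeConfig` and the
# `VAE` module, so shapes that depend on it are left symbolic:
#
#   encoder_outputs[0]             # (batch, seq_len, d_model)
#   self.vae(hidden_states, None)  # -> (reconstructed states, latent codes)
#   self.t5.decoder(..., encoder_hidden_states=reconstructed states)
#   lm_logits                      # (batch, decoder_seq_len, vocab_size)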


class FlaxT5VaePreTrainedModel(FlaxPreTrainedModel, VaeFlaxGenerationMixin):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5VaeConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: T5VaeConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = jnp.ones_like(input_ids)
        decoder_attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
        )["params"]

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: jnp.ndarray = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if decoder_input_ids is None:
            raise ValueError(
                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
            )

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # prepare decoder inputs
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

    def init_cache(self, batch_size, max_length, latent_codes):
        r"""
        Args:
            batch_size (:obj:`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (:obj:`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            latent_codes (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
                ``latent_codes`` consists of compressed hidden-states at the output of the last layer of the encoder.
                Used in the cross-attention of the decoder.
        """
        # init input variables to retrieve cache
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                **kwargs,
            )

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])
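
    # `init_cache` is invoked once per generation run from
    # `prepare_inputs_for_generation` in `FlaxT5VaeForAutoencoding` below; the
    # returned cache dict is then threaded through `decode` as
    # `past_key_values` and kept up to date by `update_inputs_for_generation`.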

    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        raise NotImplementedError()

    def decode(
        self,
        decoder_input_ids,
        latent_codes,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        raise NotImplementedError()


class FlaxT5VaeForAutoencoding(FlaxT5VaePreTrainedModel):
    module_class = FlaxT5VaeForAutoencodingModule

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Adapted from `FlaxT5PreTrainedModel`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if decoder_input_ids is None:
            raise ValueError(
                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
            )

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # prepare decoder inputs
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, input_ids, attention_mask, **kwargs):
            encode_module = module._get_encoder_module()
            vae_encoder_module = module._get_vae_encoder_module()
            return vae_encoder_module(encode_module(input_ids, attention_mask, **kwargs)[0])

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )
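
    # Note: unlike the standard `FlaxT5PreTrainedModel.encode`, this method
    # returns the VAE latent codes produced by `_encoder_forward` rather than a
    # `FlaxBaseModelOutput`; the returned array can be passed straight into
    # `decode` as `latent_codes`.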

    def decode(
        self,
        decoder_input_ids,
        latent_codes,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example::

            >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-small')
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')

            >>> text = "My friends are cool but they eat too many carbs."
            >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
            >>> latent_codes = model.encode(**inputs)

            >>> decoder_start_token_id = model.config.decoder_start_token_id
            >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

            >>> outputs = model.decode(decoder_input_ids, latent_codes)
            >>> logits = outputs.logits
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_attention_mask is None:
            batch_size, sequence_length = latent_codes.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed then the cache is already initialized and a private flag
        # `init_cache` has to be passed down to ensure the cache is used. It also has to be made sure
        # that the cache is marked as mutable so that it can be changed by the FlaxT5Attention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
            vae_decoder_module = module._get_vae_decoder_module()
            decoder_module = module._get_decoder_module()
            decoder_outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                encoder_hidden_states=vae_decoder_module(latent_codes),
                **kwargs,
            )
            sequence_output = decoder_outputs[0]

            if self.config.tie_word_embeddings:
                # Rescale output before projecting on vocab
                # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
                sequence_output = sequence_output * (self.config.t5.d_model ** -0.5)

            if self.config.tie_word_embeddings:
                shared_embedding = module.t5.shared.variables["params"]["embedding"]
                lm_logits = module.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
            else:
                lm_logits = module.t5.lm_head(sequence_output)

            return lm_logits, decoder_outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            latent_codes=latent_codes,
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        latent_codes=None,
        **kwargs
    ):
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, latent_codes)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            extended_attention_mask = jax.lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )

        return {
            "past_key_values": past_key_values,
            "latent_codes": latent_codes,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs
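

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. It assumes a T5-VAE
    # checkpoint compatible with `T5VaeConfig` has been saved at `MODEL_PATH`
    # (a placeholder path) together with a matching T5 tokenizer.
    from transformers import T5TokenizerFast

    MODEL_PATH = "path/to/t5-vae-checkpoint"  # hypothetical checkpoint directory
    tokenizer = T5TokenizerFast.from_pretrained(MODEL_PATH)
    model = FlaxT5VaeForAutoencoding.from_pretrained(MODEL_PATH)

    inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="jax")

    # Compress the sentence into latent codes (T5 encoder + VAE encoder).
    latent_codes = model.encode(inputs.input_ids, attention_mask=inputs.attention_mask)

    # Reconstruct token logits from the latent codes (VAE decoder + T5 decoder).
    start_id = model.config.decoder_start_token_id
    decoder_input_ids = jnp.full((inputs.input_ids.shape[0], 1), start_id, dtype="i4")
    outputs = model.decode(decoder_input_ids, latent_codes)
    print(outputs.logits.shape)  # (batch, 1, vocab_size)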