Fraser committed on
Commit 1d30073
1 Parent(s): cecc83b

cope with submodules not being allowed

app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
  import jax.numpy as jnp
  from transformers import AutoTokenizer
  from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
- from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding
+ from t5_vae_flax_alt.src.t5_vae import FlaxT5VaeForAutoencoding


  st.title('T5-VAE')
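
Since the package now lives in-tree as `t5_vae_flax_alt` rather than as a git submodule, the app imports the model class directly from that directory. A minimal sketch of driving the model with the new import path, assuming a placeholder checkpoint path (the actual checkpoint used by app.py is not shown in this commit):

from transformers import AutoTokenizer
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
from t5_vae_flax_alt.src.t5_vae import FlaxT5VaeForAutoencoding

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = FlaxT5VaeForAutoencoding.from_pretrained("path/to/t5-vae-checkpoint")  # placeholder, not from this commit

# Every input must be padded to the configured block size (config.set_seq_size).
inputs = tokenizer("A sample sentence.", padding="max_length",
                   max_length=model.config.set_seq_size, return_tensors="jax")
decoder_input_ids = shift_tokens_right(
    inputs.input_ids, model.config.pad_token_id, model.config.decoder_start_token_id
)
outputs = model(inputs.input_ids, attention_mask=inputs.attention_mask,
                decoder_input_ids=decoder_input_ids)
print(outputs.latent_codes.shape)  # (batch, n_latent_tokens, latent_token_size)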
t5_vae_flax_alt/.gitignore ADDED
@@ -0,0 +1,3 @@
+ *.pyc
+ venv
+ .vscode
t5_vae_flax_alt/README.md ADDED
@@ -0,0 +1,3 @@
+ # t5-vae-flax
+
+ Model code for running a T5-VAE with flax.
t5_vae_flax_alt/__init__.py ADDED
File without changes
t5_vae_flax_alt/src/__init__.py ADDED
File without changes
t5_vae_flax_alt/src/config.py ADDED
@@ -0,0 +1,137 @@
+ import copy
+ from transformers.utils import logging
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers import AutoConfig, T5Config
+
+ from t5_vae_flax_alt.src.encoders import VAE_ENCODER_MODELS
+ from t5_vae_flax_alt.src.decoders import VAE_DECODER_MODELS
+ from t5_vae_flax_alt.src.utils import assertEqual, assertIn
+
+ logger = logging.get_logger(__name__)
+
+
+ class T5VaeConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of :class:`FlaxT5VAE`.
+     It is used to instantiate a T5-VAE model according to the specified arguments, defining the model architecture.
+     Instantiating a configuration with the defaults will yield a similar configuration to that of the T5 `t5-vae-base` architecture.
+
+     To be able to use `transformers.trainer.Trainer` we need some specific training logic & config in the model.
+
+     Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
+     outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
+
+     Arguments:
+         n_latent_tokens (:obj:`int`, `optional`, defaults to 6):
+             Number of latent tokens (must be less than seq length).
+         latent_token_size (:obj:`int`, `optional`, defaults to 32):
+             Number of dimensions to use for each latent token.
+         t5_model_name_or_path (:obj:`str`, `optional`, defaults to :obj:`None`):
+             Name of the Transformer model to use as a decoder.
+         block_size (:obj:`int`, `optional`, defaults to 60):
+             NOTE: Every input sequence must be padded to be equal to this length.
+     """
+     model_type = "transformer_vae"
+     is_composition = True
+
+     def __init__(
+         self,
+         t5_model_name_or_path=None,
+         n_latent_tokens=6,  # set to -1 for full sequence
+         latent_token_size=32,
+         vae_encoder_model='',
+         vae_decoder_model='',
+         block_size=60,
+         decoder_start_token_id=0,
+         cache_dir=None,
+         tie_word_embeddings=True,
+         # T5 config
+         t5=dict(),
+         vocab_size=32128,
+         d_model=512,
+         d_kv=64,
+         d_ff=2048,
+         num_layers=6,
+         num_decoder_layers=None,
+         num_heads=8,
+         relative_attention_num_buckets=32,
+         dropout_rate=0.1,
+         layer_norm_epsilon=1e-6,
+         initializer_factor=1.0,
+         feed_forward_proj="relu",
+         is_encoder_decoder=True,
+         use_cache=True,
+         pad_token_id=0,
+         eos_token_id=1,
+         gradient_checkpointing=False,
+         # end
+         **kwargs,
+     ):
+         assertIn(vae_encoder_model, VAE_ENCODER_MODELS.keys(), "Unexpected VAE encoder.")
+         assertIn(vae_decoder_model, VAE_DECODER_MODELS.keys(), "Unexpected VAE decoder.")
+
+         super().__init__(**kwargs)
+
+         self.set_seq_size = block_size
+
+         # VAE
+         self.vae_encoder_model = vae_encoder_model
+         self.vae_decoder_model = vae_decoder_model
+
+         self.latent_token_size = latent_token_size
+         assert n_latent_tokens <= self.set_seq_size, 'Cannot use more latent tokens than input tokens.'
+         self.n_latent_tokens = n_latent_tokens
+         self.use_cache = use_cache
+
+         # T5
+         if t5_model_name_or_path:
+             self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
+             assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
+             self.t5.decoder_start_token_id = decoder_start_token_id
+         elif t5:
+             # use for loading a config
+             self.t5 = T5Config(**t5)
+         else:
+             self.t5 = T5Config(
+                 vocab_size=vocab_size,
+                 d_model=d_model,
+                 d_kv=d_kv,
+                 d_ff=d_ff,
+                 num_layers=num_layers,
+                 num_decoder_layers=num_decoder_layers,
+                 num_heads=num_heads,
+                 relative_attention_num_buckets=relative_attention_num_buckets,
+                 dropout_rate=dropout_rate,
+                 layer_norm_epsilon=layer_norm_epsilon,
+                 initializer_factor=initializer_factor,
+                 feed_forward_proj=feed_forward_proj,
+                 is_encoder_decoder=is_encoder_decoder,
+                 use_cache=use_cache,
+                 pad_token_id=pad_token_id,
+                 eos_token_id=eos_token_id,
+                 gradient_checkpointing=gradient_checkpointing,
+                 **kwargs
+             )
+
+         if self.t5.d_model < self.latent_token_size:
+             raise Exception('Using a larger latent token dimension than the T5 hidden dimension.')
+
+         # Add t5 config options
+         self.tie_word_embeddings = tie_word_embeddings
+         self.t5.tie_word_embeddings = self.tie_word_embeddings
+         self.t5.use_cache = self.use_cache
+         self.pad_token_id = pad_token_id
+         self.eos_token_id = eos_token_id
+         self.decoder_start_token_id = self.t5.decoder_start_token_id
+
+     def to_dict(self):
+         """
+         Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.
+
+         Returns:
+             :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+         """
+         output = copy.deepcopy(self.__dict__)
+         output["model_type"] = self.__class__.model_type
+         output['t5'] = self.t5.to_dict()
+         return output
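
The config can be built either from a T5 checkpoint name or from the explicit T5 keyword arguments above. A minimal sketch, assuming the defaults shown in this file (none of the values below come from the Space itself):

from t5_vae_flax_alt.src.config import T5VaeConfig

# Pulls the matching T5 sub-config via AutoConfig; block_size becomes set_seq_size.
config = T5VaeConfig(
    t5_model_name_or_path="t5-base",
    n_latent_tokens=6,
    latent_token_size=32,
    block_size=60,
)
assert config.t5.model_type == "t5"
print(config.n_latent_tokens, config.latent_token_size, config.set_seq_size)  # 6 32 60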
t5_vae_flax_alt/src/decoders.py ADDED
@@ -0,0 +1,23 @@
+ import logging
+ import flax.linen as nn
+
+ logger = logging.getLogger(__name__)
+
+
+ class Decoder(nn.Module):
+     '''
+     Converts latent code -> transformer encoding.
+     '''
+     dim_model: int
+     n_latent_tokens: int
+
+     @nn.compact
+     def __call__(self, latent_code):  # (batch, latent_tokens_per_sequence, latent_token_dim)
+         raw_latent_tokens = nn.Dense(self.dim_model)(latent_code)
+         latent_tokens = nn.LayerNorm()(raw_latent_tokens)
+         return latent_tokens  # (batch, latent_tokens_per_sequence, dim_model)
+
+
+ VAE_DECODER_MODELS = {
+     '': Decoder,
+ }
t5_vae_flax_alt/src/encoders.py ADDED
@@ -0,0 +1,26 @@
+ import logging
+ import jax.numpy as jnp
+ import flax.linen as nn
+
+ logger = logging.getLogger(__name__)
+
+
+ class Encoder(nn.Module):
+     '''
+     Converts N hidden tokens into N separate latent codes.
+     '''
+     latent_token_size: int
+     n_latent_tokens: int
+
+     @nn.compact
+     def __call__(self, encoding):
+         latent_tokens = nn.Dense(self.latent_token_size)(encoding)
+         raw_latent_code = latent_tokens[:, : self.n_latent_tokens, :]
+         # TODO does this just apply tanh to each latent token? Or across the whole batch?
+         latent_code = jnp.tanh(raw_latent_code)
+         return latent_code  # (batch, latent_tokens_per_sequence, latent_token_dim)
+
+
+ VAE_ENCODER_MODELS = {
+     '': Encoder,
+ }
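
The Encoder and Decoder above are pure shape adapters around the T5 hidden states. A minimal sketch of their shape contract, using illustrative sizes (batch 2, 60 hidden tokens, model dim 512) that are not taken from the commit itself:

import jax
import jax.numpy as jnp
from t5_vae_flax_alt.src.encoders import Encoder
from t5_vae_flax_alt.src.decoders import Decoder

hidden = jnp.zeros((2, 60, 512))  # stand-in for T5 encoder output

encoder = Encoder(latent_token_size=32, n_latent_tokens=6)
enc_params = encoder.init(jax.random.PRNGKey(0), hidden)
latent = encoder.apply(enc_params, hidden)   # (2, 6, 32)

decoder = Decoder(dim_model=512, n_latent_tokens=6)
dec_params = decoder.init(jax.random.PRNGKey(1), latent)
remade = decoder.apply(dec_params, latent)   # (2, 6, 512)

print(latent.shape, remade.shape)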
t5_vae_flax_alt/src/generate.py ADDED
@@ -0,0 +1,185 @@
+ from typing import Dict, Optional
+
+ import jax
+ import jax.numpy as jnp
+ import jaxlib.xla_extension as jax_xla
+
+ from transformers.generation_flax_utils import FlaxGenerationMixin
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class VaeFlaxGenerationMixin(FlaxGenerationMixin):
+     def generate(
+         self,
+         latent_codes: jax_xla.DeviceArray,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         bos_token_id: Optional[int] = None,
+         eos_token_id: Optional[int] = None,
+         decoder_start_token_id: Optional[int] = None,
+         do_sample: Optional[bool] = None,
+         prng_key: Optional[jax_xla.DeviceArray] = None,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         temperature: Optional[float] = None,
+         num_beams: Optional[int] = None,
+         no_repeat_ngram_size: Optional[int] = None,
+         min_length: Optional[int] = None,
+         forced_bos_token_id: Optional[int] = None,
+         forced_eos_token_id: Optional[int] = None,
+         length_penalty: Optional[float] = None,
+         early_stopping: Optional[bool] = None,
+         trace: bool = True,
+         params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+         **model_kwargs,
+     ):
+         r"""
+         Generates sequences for models with a language modeling head. The method currently supports greedy decoding
+         and multinomial sampling.
+
+         Apart from :obj:`latent_codes`, all the arguments below will default to the value of the attribute of the same
+         name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the
+         default values of that config.
+
+         Most of these parameters are explained in more detail in `this blog post
+         <https://huggingface.co/blog/how-to-generate>`__.
+
+         Parameters:
+
+             latent_codes (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_latent_tokens, latent_token_dim)`, `optional`):
+                 The sequence used as a prompt for the generation.
+             max_length (:obj:`int`, `optional`, defaults to 20):
+                 The maximum length of the sequence to be generated.
+             do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                 Whether or not to use sampling; use greedy decoding otherwise.
+             temperature (:obj:`float`, `optional`, defaults to 1.0):
+                 The value used to module the next token probabilities.
+             top_k (:obj:`int`, `optional`, defaults to 50):
+                 The number of highest probability vocabulary tokens to keep for top-k-filtering.
+             top_p (:obj:`float`, `optional`, defaults to 1.0):
+                 If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
+                 higher are kept for generation.
+             pad_token_id (:obj:`int`, `optional`):
+                 The id of the `padding` token.
+             bos_token_id (:obj:`int`, `optional`):
+                 The id of the `beginning-of-sequence` token.
+             eos_token_id (:obj:`int`, `optional`):
+                 The id of the `end-of-sequence` token.
+             num_beams (:obj:`int`, `optional`, defaults to 1):
+                 Number of beams for beam search. 1 means no beam search.
+             decoder_start_token_id (:obj:`int`, `optional`):
+                 If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
+             trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                 Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
+                 a considerably slower runtime.
+             params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
+                 Optionally the model parameters can be passed. Can be useful for parallelized generation.
+             model_kwargs:
+                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
+
+         Return:
+             :class:`~transformers.file_utils.ModelOutput`.
+
+         Examples::
+             >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
+
+             >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+             >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
+             >>> input_context = "The dog"
+             >>> # encode input context
+             >>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids
+             >>> # generate candidates using sampling
+             >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
+             >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+         """
+         # set init values
+         max_length = max_length if max_length is not None else self.config.max_length
+         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+         decoder_start_token_id = (
+             decoder_start_token_id if decoder_start_token_id else self.config.decoder_start_token_id
+         )
+         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
+
+         if decoder_start_token_id is None and self.config.is_encoder_decoder:
+             raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
+
+         model_kwargs['latent_codes'] = latent_codes
+
+         if self.config.is_encoder_decoder:
+             # add encoder_outputs to model_kwargs
+             # NOTE: Don't prepare encoder outputs, instead rely on latent_codes.
+             # model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
+             # prepare decoder_input_ids for generation
+             input_ids = jnp.ones((latent_codes.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+         do_sample = do_sample if do_sample is not None else self.config.do_sample
+         num_beams = num_beams if num_beams is not None else self.config.num_beams
+
+         if not do_sample and num_beams == 1:
+             logits_processor = self._get_logits_processor(
+                 no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
+             )
+             return self._greedy_search(
+                 input_ids,
+                 max_length,
+                 pad_token_id,
+                 eos_token_id,
+                 logits_processor=logits_processor,
+                 trace=trace,
+                 params=params,
+                 model_kwargs=model_kwargs,
+             )
+         elif do_sample and num_beams == 1:
+             logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
+             logits_processor = self._get_logits_processor(
+                 no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
+             )
+             return self._sample(
+                 input_ids,
+                 max_length,
+                 pad_token_id,
+                 eos_token_id,
+                 prng_key,
+                 logits_warper=logits_warper,
+                 logits_processor=logits_processor,
+                 trace=trace,
+                 params=params,
+                 model_kwargs=model_kwargs,
+             )
+         elif not do_sample and num_beams > 1:
+             # broadcast input_ids & encoder_outputs
+             input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
+
+             if "encoder_outputs" in model_kwargs:
+                 model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
+                     model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
+                 )
+
+             if "attention_mask" in model_kwargs:
+                 model_kwargs["attention_mask"] = self._expand_to_num_beams(
+                     model_kwargs["attention_mask"], num_beams=num_beams
+                 )
+
+             logits_processor = self._get_logits_processor(
+                 no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
+             )
+
+             return self._beam_search(
+                 input_ids,
+                 max_length,
+                 pad_token_id,
+                 eos_token_id,
+                 length_penalty=length_penalty,
+                 early_stopping=early_stopping,
+                 logits_processor=logits_processor,
+                 trace=trace,
+                 params=params,
+                 model_kwargs=model_kwargs,
+             )
+         else:
+             raise NotImplementedError("Beam sampling is currently not implemented.")
t5_vae_flax_alt/src/outputs.py ADDED
@@ -0,0 +1,74 @@
+ from typing import Optional, Tuple
+
+ import flax
+ import jaxlib.xla_extension as jax_xla
+
+ from transformers.file_utils import ModelOutput
+
+
+ @flax.struct.dataclass
+ class TransformerVaeOutput(ModelOutput):
+     """
+     Base class for a Transformer-VAE's outputs.
+
+     Args:
+         latent_codes (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_latent_tokens, latent_token_size)`):
+             Latent codes representing encoded sequences.
+         remade_encoder_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_tokens, model_dim)`):
+             Reconstructed encoder hidden states representing sequences.
+
+     (std Seq2Seq) Args:
+         logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+         past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+             Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
+             tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
+             tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+             blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
+         last_hidden_state (:obj:`tuple(jax_xla.DeviceArray)`):
+             Last model hidden state.
+         decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+             Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
+             layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+             Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+         decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+             Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
+             sequence_length, sequence_length)`.
+
+             Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+             self-attention heads.
+         cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+             Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
+             sequence_length, sequence_length)`.
+
+             Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+             weighted average in the cross-attention heads.
+         encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+             Sequence of hidden-states at the output of the last layer of the encoder of the model.
+         encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+             Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
+             layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+             Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+         encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+             Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
+             sequence_length, sequence_length)`.
+
+             Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+             self-attention heads.
+     """
+     logits: jax_xla.DeviceArray = None
+     latent_codes: jax_xla.DeviceArray = None
+     remade_encoder_hidden_state: jax_xla.DeviceArray = None
+     # seq2seq
+     past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
+     decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
+     decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+     cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+     last_hidden_state: Optional[jax_xla.DeviceArray] = None
+     encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
+     encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
+     encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
t5_vae_flax_alt/src/t5_vae.py ADDED
@@ -0,0 +1,520 @@
+ from typing import Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from jax.random import PRNGKey
+ import flax.linen as nn
+ from flax.core.frozen_dict import FrozenDict, unfreeze
+
+ from transformers.modeling_flax_outputs import FlaxCausalLMOutputWithCrossAttentions
+ from transformers.file_utils import add_start_docstrings
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
+ from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGenerationModule
+
+ from t5_vae_flax_alt.src.vae import VAE
+ from t5_vae_flax_alt.src.generate import VaeFlaxGenerationMixin
+ from t5_vae_flax_alt.src.outputs import TransformerVaeOutput
+ from t5_vae_flax_alt.src.config import T5VaeConfig
+
+
+ @add_start_docstrings("""T5 Model with a `language modeling` head on top converted into a VAE.""")
+ class FlaxT5VaeForAutoencodingModule(nn.Module):
+     config: T5VaeConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+     def _get_encoder_module(self):
+         return self.t5.encoder
+
+     def _get_vae_encoder_module(self):
+         return self.vae.encoder
+
+     def _get_vae_decoder_module(self):
+         return self.vae.decoder
+
+     def _get_decoder_module(self):
+         return self.t5.decoder
+
+     def setup(self):
+         self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5)
+         self.vae = VAE(self.config)
+
+     def __call__(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         decoder_input_ids=None,
+         decoder_attention_mask=None,
+         encoder_outputs=None,
+         latent_codes=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         deterministic: bool = True,
+     ):
+         """
+         Adapted from `FlaxT5ForConditionalGenerationModule`
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Encode
+         encoder_outputs = self.t5.encoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=deterministic,
+         )
+
+         hidden_states = encoder_outputs[0]
+
+         # Autoencode
+         hidden_states, latent_codes = self.vae(hidden_states, latent_codes)
+         encoder_attention_mask = jnp.ones((hidden_states.shape[0], hidden_states.shape[1]))
+
+         # Decode
+         decoder_outputs = self.t5.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             encoder_hidden_states=hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=deterministic,
+         )
+
+         sequence_output = decoder_outputs[0]
+
+         if self.config.tie_word_embeddings:
+             # Rescale output before projecting on vocab
+             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+             sequence_output = sequence_output * (self.config.t5.d_model ** -0.5)
+
+         if self.t5.config.tie_word_embeddings:
+             shared_embedding = self.t5.shared.variables["params"]["embedding"]
+             lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
+         else:
+             lm_logits = self.t5.lm_head(sequence_output)
+
+         if not return_dict:
+             return [lm_logits, latent_codes] + decoder_outputs[1:] + encoder_outputs
+
+         return TransformerVaeOutput(
+             logits=lm_logits,
+             latent_codes=latent_codes,
+             last_hidden_state=decoder_outputs.last_hidden_state,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+             encoder_hidden_states=encoder_outputs.hidden_states,
+             encoder_attentions=encoder_outputs.attentions,
+         )
+
+
+ class FlaxT5VaePreTrainedModel(FlaxPreTrainedModel, VaeFlaxGenerationMixin):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = T5VaeConfig
+     base_model_prefix = "transformer"
+     module_class: nn.Module = None
+
+     def __init__(
+         self,
+         config: T5VaeConfig,
+         input_shape: Tuple[int] = (1, 1),
+         seed: int = 0,
+         dtype: jnp.dtype = jnp.float32,
+         **kwargs
+     ):
+         module = self.module_class(config=config, dtype=dtype, **kwargs)
+         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+
+     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+         # init input tensors
+         input_ids = jnp.zeros(input_shape, dtype="i4")
+
+         attention_mask = jnp.ones_like(input_ids)
+         decoder_input_ids = jnp.ones_like(input_ids)
+         decoder_attention_mask = jnp.ones_like(input_ids)
+
+         params_rng, dropout_rng = jax.random.split(rng)
+         rngs = {"params": params_rng, "dropout": dropout_rng}
+
+         return self.module.init(
+             rngs,
+             input_ids,
+             attention_mask,
+             decoder_input_ids,
+             decoder_attention_mask,
+         )["params"]
+
+     def __call__(
+         self,
+         input_ids: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         decoder_input_ids: jnp.ndarray = None,
+         decoder_attention_mask: Optional[jnp.ndarray] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         if decoder_input_ids is None:
+             raise ValueError(
+                 "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
+             )
+
+         # prepare encoder inputs
+         if attention_mask is None:
+             attention_mask = jnp.ones_like(input_ids)
+
+         # prepare decoder inputs
+         if decoder_attention_mask is None:
+             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+         # Handle any PRNG if needed
+         rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+         return self.module.apply(
+             {"params": params or self.params},
+             input_ids=jnp.array(input_ids, dtype="i4"),
+             attention_mask=jnp.array(attention_mask, dtype="i4"),
+             decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+             decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=not train,
+             rngs=rngs,
+         )
+
+     def init_cache(self, batch_size, max_length, latent_codes):
+         r"""
+         Args:
+             batch_size (:obj:`int`):
+                 batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+             max_length (:obj:`int`):
+                 maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+                 cache.
+             latent_codes (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
+                 ``latent_codes`` consists of compressed hidden-states at the output of the last layer of the encoder.
+                 Used in the cross-attention of the decoder.
+         """
+         # init input variables to retrieve cache
+         decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+         decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+         def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+             decoder_module = module._get_decoder_module()
+             return decoder_module(
+                 decoder_input_ids,
+                 decoder_attention_mask,
+                 **kwargs,
+             )
+
+         init_variables = self.module.init(
+             jax.random.PRNGKey(0),
+             decoder_input_ids=decoder_input_ids,
+             decoder_attention_mask=decoder_attention_mask,
+             init_cache=True,
+             method=_decoder_forward,  # we only need to call the decoder to init the cache
+         )
+         return unfreeze(init_variables["cache"])
+
+     def encode(
+         self,
+         input_ids: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         raise NotImplementedError()
+
+     def decode(
+         self,
+         decoder_input_ids,
+         latent_codes,
+         encoder_attention_mask: Optional[jnp.ndarray] = None,
+         decoder_attention_mask: Optional[jnp.ndarray] = None,
+         past_key_values: dict = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         raise NotImplementedError()
+
+
+ class FlaxT5VaeForAutoencoding(FlaxT5VaePreTrainedModel):
+     module_class = FlaxT5VaeForAutoencodingModule
+
+     def __call__(
+         self,
+         input_ids: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         decoder_input_ids=None,
+         decoder_attention_mask=None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         '''
+         Adapted from `FlaxT5PreTrainedModel`
+         '''
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         if decoder_input_ids is None:
+             raise ValueError(
+                 "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
+             )
+
+         # prepare encoder inputs
+         if attention_mask is None:
+             attention_mask = jnp.ones_like(input_ids)
+
+         # prepare decoder inputs
+         if decoder_attention_mask is None:
+             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+         # Handle any PRNG if needed
+         rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+         return self.module.apply(
+             {"params": params or self.params},
+             input_ids=jnp.array(input_ids, dtype="i4"),
+             attention_mask=jnp.array(attention_mask, dtype="i4"),
+             decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+             decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=not train,
+             rngs=rngs,
+         )
+
+     def encode(
+         self,
+         input_ids: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         if attention_mask is None:
+             attention_mask = jnp.ones_like(input_ids)
+
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         def _encoder_forward(module, input_ids, attention_mask, **kwargs):
+             encode_module = module._get_encoder_module()
+             vae_encoder_module = module._get_vae_encoder_module()
+             return vae_encoder_module(encode_module(input_ids, attention_mask, **kwargs)[0])
+
+         return self.module.apply(
+             {"params": params or self.params},
+             input_ids=jnp.array(input_ids, dtype="i4"),
+             attention_mask=jnp.array(attention_mask, dtype="i4"),
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=not train,
+             rngs=rngs,
+             method=_encoder_forward,
+         )
+
+     def decode(
+         self,
+         decoder_input_ids,
+         latent_codes,
+         encoder_attention_mask: Optional[jnp.ndarray] = None,
+         decoder_attention_mask: Optional[jnp.ndarray] = None,
+         past_key_values: dict = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         r"""
+         Returns:
+
+         Example::
+
+             >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-small')
+             >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
+
+             >>> text = "My friends are cool but they eat too many carbs."
+             >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
+             >>> latent_codes = model.encode(**inputs)
+
+             >>> decoder_start_token_id = model.config.decoder_start_token_id
+             >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+             >>> outputs = model.decode(decoder_input_ids, latent_codes)
+             >>> last_decoder_hidden_states = outputs.last_hidden_state
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         if encoder_attention_mask is None:
+             batch_size, sequence_length = latent_codes.shape[:2]
+             encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+         batch_size, sequence_length = decoder_input_ids.shape
+         if decoder_attention_mask is None:
+             decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         inputs = {"params": params or self.params}
+
+         # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
+         # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
+         # it can be changed by FlaxT5Attention module
+         if past_key_values:
+             inputs["cache"] = past_key_values
+             mutable = ["cache"]
+         else:
+             mutable = False
+
+         def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
+             vae_decoder_module = module._get_vae_decoder_module()
+             decoder_module = module._get_decoder_module()
+             decoder_outputs = decoder_module(
+                 decoder_input_ids,
+                 decoder_attention_mask,
+                 encoder_hidden_states=vae_decoder_module(latent_codes),
+                 **kwargs,
+             )
+             sequence_output = decoder_outputs[0]
+
+             if self.config.tie_word_embeddings:
+                 # Rescale output before projecting on vocab
+                 # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+                 sequence_output = sequence_output * (self.config.t5.d_model ** -0.5)
+
+             if self.config.tie_word_embeddings:
+                 shared_embedding = module.t5.shared.variables["params"]["embedding"]
+                 lm_logits = module.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
+             else:
+                 lm_logits = module.t5.lm_head(sequence_output)
+
+             return lm_logits, decoder_outputs
+
+         outputs = self.module.apply(
+             inputs,
+             decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+             latent_codes=latent_codes,
+             decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+             encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=not train,
+             rngs=rngs,
+             mutable=mutable,
+             method=_decoder_forward,
+         )
+
+         if past_key_values is None:
+             lm_logits, decoder_outputs = outputs
+         else:
+             (lm_logits, decoder_outputs), past = outputs
+
+         if return_dict:
+             outputs = FlaxCausalLMOutputWithCrossAttentions(
+                 logits=lm_logits,
+                 hidden_states=decoder_outputs.hidden_states,
+                 attentions=decoder_outputs.attentions,
+                 cross_attentions=decoder_outputs.cross_attentions,
+             )
+         else:
+             outputs = (lm_logits,) + decoder_outputs[1:]
+
+         # add updated cache to model output
+         if past_key_values is not None and return_dict:
+             outputs["past_key_values"] = unfreeze(past["cache"])
+             return outputs
+         elif past_key_values is not None and not return_dict:
+             outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+         return outputs
+
+     def prepare_inputs_for_generation(
+         self,
+         decoder_input_ids,
+         max_length,
+         attention_mask: Optional[jnp.ndarray] = None,
+         decoder_attention_mask: Optional[jnp.ndarray] = None,
+         latent_codes=None,
+         **kwargs
+     ):
+         # initializing the cache
+         batch_size, seq_length = decoder_input_ids.shape
+
+         past_key_values = self.init_cache(batch_size, max_length, latent_codes)
+         # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+         # But since the decoder uses a causal mask, those positions are masked anyways.
+         # Thus we can create a single static attention_mask here, which is more efficient for compilation
+         extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+         if decoder_attention_mask is not None:
+             extended_attention_mask = jax.lax.dynamic_update_slice(
+                 extended_attention_mask, decoder_attention_mask, (0, 0)
+             )
+
+         return {
+             "past_key_values": past_key_values,
+             "latent_codes": latent_codes,
+             "encoder_attention_mask": attention_mask,
+             "decoder_attention_mask": extended_attention_mask,
+         }
+
+     def update_inputs_for_generation(self, model_outputs, model_kwargs):
+         model_kwargs["past_key_values"] = model_outputs.past_key_values
+         return model_kwargs
t5_vae_flax_alt/src/utils.py ADDED
@@ -0,0 +1,24 @@
+ from typing import Sequence
+
+ import flax.linen as nn
+
+
+ class MLP(nn.Module):
+     features: Sequence[int]
+
+     @nn.compact
+     def __call__(self, x):
+         for feat in self.features[:-1]:
+             x = nn.relu(nn.Dense(feat)(x))
+         x = nn.Dense(self.features[-1])(x)
+         return x
+
+
+ def assertEqual(actual, expected, msg, first="Got", second="Expected"):
+     if actual != expected:
+         raise ValueError(msg + f' {first}: "{actual}" {second}: "{expected}"')
+
+
+ def assertIn(actual, expected, msg, first="Got", second="Expected one of"):
+     if actual not in expected:
+         raise ValueError(msg + f' {first}: "{actual}" {second}: {expected}')
t5_vae_flax_alt/src/vae.py ADDED
@@ -0,0 +1,30 @@
+ import jax.numpy as jnp
+ import flax.linen as nn
+
+ from t5_vae_flax_alt.src.encoders import VAE_ENCODER_MODELS
+ from t5_vae_flax_alt.src.decoders import VAE_DECODER_MODELS
+ from t5_vae_flax_alt.src.config import T5VaeConfig
+
+
+ class VAE(nn.Module):
+     # see https://github.com/google/flax#what-does-flax-look-like
+     """
+     An MMD-VAE used with encoder-decoder models.
+     Encodes all token encodings into a single latent & spits them back out.
+     """
+     config: T5VaeConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+     def setup(self):
+         self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_token_size, self.config.n_latent_tokens)
+         self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
+
+     def __call__(self, encoding=None, latent_codes=None):
+         latent_codes = self.encode(encoding)
+         return self.decode(latent_codes), latent_codes
+
+     def encode(self, encoding):
+         return self.encoder(encoding)
+
+     def decode(self, latent):
+         return self.decoder(latent)
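
To tie the pieces together, the VAE module reads its sizes from T5VaeConfig (latent_token_size for the encoder, t5.d_model for the decoder). A minimal sketch using the config's local T5 defaults so nothing is downloaded; the sizes are illustrative and not taken from the commit itself:

import jax
import jax.numpy as jnp
from t5_vae_flax_alt.src.config import T5VaeConfig
from t5_vae_flax_alt.src.vae import VAE

config = T5VaeConfig(n_latent_tokens=6, latent_token_size=32, block_size=60)
vae = VAE(config)

hidden = jnp.zeros((2, config.set_seq_size, config.t5.d_model))  # fake T5 encoder output
params = vae.init(jax.random.PRNGKey(0), hidden)
remade, latent_codes = vae.apply(params, hidden)
print(latent_codes.shape, remade.shape)  # (2, 6, 32) (2, 6, 512)

Only the first n_latent_tokens positions survive the round trip, which is what FlaxT5VaeForAutoencodingModule then feeds to the T5 decoder as encoder_hidden_states.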