Fraser commited on
Commit
0b69648
1 Parent(s): dfe0088

add transformer-vae code

Browse files
Files changed (14) hide show
  1. .gitignore +3 -0
  2. README.md +36 -0
  3. check_install.py +15 -0
  4. model/__init__.py +0 -0
  5. model/config.py +137 -0
  6. model/decoders.py +23 -0
  7. model/encoders.py +26 -0
  8. model/outputs.py +74 -0
  9. model/t5_vae.py +522 -0
  10. model/utils.py +24 -0
  11. model/vae.py +30 -0
  12. requirements.txt +3 -0
  13. train.py +706 -0
  14. train.sh +0 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ .vscode
2
+ venv
3
+ *.pyc
README.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Transformer-VAE (flax) (WIP)
2
+
3
+ A Transformer-VAE made using flax.
4
+
5
+ Done as part of Huggingface community training ([see forum post](https://discuss.huggingface.co/t/train-a-vae-to-interpolate-on-english-sentences/7548)).
6
+
7
+ Builds on T5, using an autoencoder to convert it into a VAE.
8
+
9
+ [See training logs.](https://wandb.ai/fraser/flax-vae)
10
+
11
+ ## ToDo
12
+
13
+ - [ ] Basic training script working. (Fraser + Theo)
14
+ - [ ] Add MMD loss (Theo)
15
+
16
+ - [ ] Save a wikipedia sentences dataset to Huggingface (see original https://github.com/ChunyuanLI/Optimus/blob/master/data/download_datasets.md) (Mina)
17
+ - [ ] Make a tokenizer using the OPTIMUS tokenized dataset.
18
+ - [ ] Train on the OPTIMUS wikipedia sentences dataset.
19
+
20
+ - [ ] Make Huggingface widget interpolating sentences! (???) https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-build-a-demo
21
+
22
+ Optional ToDos:
23
+
24
+ - [ ] Add Funnel transformer encoder to FLAX (don't need weights).
25
+ - [ ] Train a Funnel-encoder + T5-decoder transformer VAE.
26
+
27
+ - [ ] Additional datasets:
28
+ - [ ] Poetry (https://www.gwern.net/GPT-2#data-the-project-gutenberg-poetry-corpus)
29
+ - [ ] 8-bit music (https://github.com/chrisdonahue/LakhNES)
30
+
31
+ ## Setup
32
+
33
+ Follow all steps to install dependencies from https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm
34
+
35
+ - [ ] Find dataset storage site.
36
+ - [ ] Ask JAX team for dataset storage.
check_install.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import FlaxRobertaModel, RobertaTokenizerFast
2
+ from datasets import load_dataset
3
+ import jax
4
+
5
+ dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
6
+
7
+ dummy_input = next(iter(dataset))["text"]
8
+
9
+ tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
10
+ input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]
11
+
12
+ model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
13
+
14
+ # run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
15
+ z = model(input_ids)
model/__init__.py ADDED
File without changes
model/config.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from transformers.utils import logging
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers import AutoConfig, T5Config
5
+
6
+ from model.encoders import VAE_ENCODER_MODELS
7
+ from model.decoders import VAE_DECODER_MODELS
8
+ from model.utils import assertEqual, assertIn
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ class T5VaeConfig(PretrainedConfig):
14
+ r"""
15
+ This is the configuration class to store the configuration of :class:`FlaxT5VAE`.
16
+ It is used to instantiate a T5-VAE model according to the specified arguments, defining the model architecture.
17
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the T5 `t5-vae-base architecture.
18
+
19
+ To be able to use `transformer.trainer.Trainer` we need some specific training logic & config in the model.
20
+
21
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
22
+ outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
23
+
24
+ Arguments:
25
+ n_latent_tokens (:obj:`int`, `optional`, defaults to 6):
26
+ Number of latent tokens (must be less than seq length).
27
+ latent_token_size (:obj:`int`, `optional`, defaults to 32):
28
+ Number of dimensions to use for each latent token.
29
+ t5_name (:obj:`str`, `optional`, defaults to t5-base):
30
+ Name of the Transformer model to use as a decoder.
31
+ block_size (:obj:`int`, `optional`, defaults to 60):
32
+ NOTE: Every input sequence must be padded to be equal to this length.
33
+ """
34
+ model_type = "transformer_vae"
35
+ is_composition = True
36
+
37
+ def __init__(
38
+ self,
39
+ t5_model_name_or_path=None,
40
+ n_latent_tokens=6, # set to -1 for full sequence
41
+ latent_token_size=32,
42
+ vae_encoder_model='',
43
+ vae_decoder_model='',
44
+ block_size=60,
45
+ decoder_start_token_id=0,
46
+ cache_dir=None,
47
+ tie_word_embeddings=True,
48
+ # T5 config
49
+ t5=dict(),
50
+ vocab_size=32128,
51
+ d_model=512,
52
+ d_kv=64,
53
+ d_ff=2048,
54
+ num_layers=6,
55
+ num_decoder_layers=None,
56
+ num_heads=8,
57
+ relative_attention_num_buckets=32,
58
+ dropout_rate=0.1,
59
+ layer_norm_epsilon=1e-6,
60
+ initializer_factor=1.0,
61
+ feed_forward_proj="relu",
62
+ is_encoder_decoder=True,
63
+ use_cache=True,
64
+ pad_token_id=0,
65
+ eos_token_id=1,
66
+ gradient_checkpointing=False,
67
+ # end
68
+ **kwargs,
69
+ ):
70
+ assertIn(vae_encoder_model, VAE_ENCODER_MODELS.keys(), "Unexpected VAE encoder.")
71
+ assertIn(vae_decoder_model, VAE_DECODER_MODELS.keys(), "Unexpected VAE decoder.")
72
+
73
+ super().__init__(**kwargs)
74
+
75
+ self.set_seq_size = block_size
76
+
77
+ # VAE
78
+ self.vae_encoder_model = vae_encoder_model
79
+ self.vae_decoder_model = vae_decoder_model
80
+
81
+ self.latent_token_size = latent_token_size
82
+ assert(n_latent_tokens <= self.set_seq_size, 'Cannot use more latent tokens than input tokens.')
83
+ self.n_latent_tokens = n_latent_tokens
84
+ self.use_cache = use_cache
85
+
86
+ # T5
87
+ if t5_model_name_or_path:
88
+ self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
89
+ assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
90
+ self.t5.decoder_start_token_id = decoder_start_token_id
91
+ elif t5:
92
+ # use for loading a config
93
+ self.t5 = T5Config(**t5)
94
+ else:
95
+ self.t5 = T5Config(
96
+ vocab_size=vocab_size,
97
+ d_model=d_model,
98
+ d_kv=d_kv,
99
+ d_ff=d_ff,
100
+ num_layers=num_layers,
101
+ num_decoder_layers=num_decoder_layers,
102
+ num_heads=num_heads,
103
+ relative_attention_num_buckets=relative_attention_num_buckets,
104
+ dropout_rate=dropout_rate,
105
+ layer_norm_epsilon=layer_norm_epsilon,
106
+ initializer_factor=initializer_factor,
107
+ feed_forward_proj=feed_forward_proj,
108
+ is_encoder_decoder=is_encoder_decoder,
109
+ use_cache=use_cache,
110
+ pad_token_id=pad_token_id,
111
+ eos_token_id=eos_token_id,
112
+ gradient_checkpointing=gradient_checkpointing,
113
+ **kwargs
114
+ )
115
+
116
+ if self.t5.d_model < self.latent_token_size:
117
+ raise Exception('Using larger latent token dimension then T5 hidden dimension.')
118
+
119
+ # Add t5 config options
120
+ self.tie_word_embeddings = tie_word_embeddings
121
+ self.t5.tie_word_embeddings = self.tie_word_embeddings
122
+ self.t5.use_cache = self.use_cache
123
+ self.pad_token_id = pad_token_id
124
+ self.eos_token_id = eos_token_id
125
+ self.decoder_start_token_id = self.t5.decoder_start_token_id
126
+
127
+ def to_dict(self):
128
+ """
129
+ Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.
130
+
131
+ Returns:
132
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
133
+ """
134
+ output = copy.deepcopy(self.__dict__)
135
+ output["model_type"] = self.__class__.model_type
136
+ output['t5'] = self.t5.to_dict()
137
+ return output
model/decoders.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import flax.linen as nn
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+
7
+ class Decoder(nn.Module):
8
+ '''
9
+ Converts latent code -> transformer encoding.
10
+ '''
11
+ dim_model: int
12
+ n_latent_tokens: int
13
+
14
+ @nn.compact
15
+ def __call__(self, latent_code): # (batch, latent_tokens_per_sequence, latent_token_dim)
16
+ raw_latent_tokens = nn.Dense(self.dim_model)(latent_code)
17
+ latent_tokens = nn.LayerNorm()(raw_latent_tokens)
18
+ return latent_tokens # (batch, latent_tokens_per_sequence, dim_model)
19
+
20
+
21
+ VAE_DECODER_MODELS = {
22
+ '': Decoder,
23
+ }
model/encoders.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import jax.numpy as jnp
3
+ import flax.linen as nn
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
+ class Encoder(nn.Module):
9
+ '''
10
+ Converts N hidden tokens into N seperate latent codes.
11
+ '''
12
+ latent_token_size: int
13
+ n_latent_tokens: int
14
+
15
+ @nn.compact
16
+ def __call__(self, encoding):
17
+ latent_tokens = nn.Dense(self.latent_token_size)(encoding)
18
+ raw_latent_code = latent_tokens[:, : self.n_latent_tokens, :]
19
+ # TODO does this just apply tanh to each latent token? Or across the whole batch
20
+ latent_code = jnp.tanh(raw_latent_code)
21
+ return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
22
+
23
+
24
+ VAE_ENCODER_MODELS = {
25
+ '': Encoder,
26
+ }
model/outputs.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import flax
4
+ import jaxlib.xla_extension as jax_xla
5
+
6
+ from transformers.file_utils import ModelOutput
7
+
8
+
9
+ @flax.struct.dataclass
10
+ class TransformerVaeOutput(ModelOutput):
11
+ """
12
+ Base class for a Transformer-VAE's outputs.
13
+
14
+ Args:
15
+ latent_codes (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_latent_tokens, latent_token_size)`):
16
+ Latent codes representing encoded sequences.
17
+ remade_encoder_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_tokens, model_dim)`):
18
+ Reconstructed encoder hidden states representing sequences.
19
+
20
+ (std Seq2Seq) Args:
21
+ logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
22
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
23
+ past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
24
+ Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
25
+ tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
26
+ tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
27
+
28
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
29
+ blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
30
+ last_hidden_state (:obj:`tuple(jax_xla.DeviceArray)`:
31
+ Last model hidden state.
32
+ decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
33
+ Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
34
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
35
+
36
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
37
+ decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
38
+ Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
39
+ sequence_length, sequence_length)`.
40
+
41
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
42
+ self-attention heads.
43
+ cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
44
+ Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
45
+ sequence_length, sequence_length)`.
46
+
47
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
48
+ weighted average in the cross-attention heads.
49
+ encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
50
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
51
+ encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
52
+ Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
53
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
54
+
55
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
56
+ encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
57
+ Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
58
+ sequence_length, sequence_length)`.
59
+
60
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
61
+ self-attention heads.
62
+ """
63
+ logits: jax_xla.DeviceArray = None
64
+ latent_codes: jax_xla.DeviceArray = None
65
+ remade_encoder_hidden_state: jax_xla.DeviceArray = None
66
+ # seq2seq
67
+ past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
68
+ decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
69
+ decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
70
+ cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
71
+ last_hidden_state: Optional[jax_xla.DeviceArray] = None
72
+ encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
73
+ encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
74
+ encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
model/t5_vae.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from jax.random import PRNGKey
6
+ import flax.linen as nn
7
+ from flax.core.frozen_dict import FrozenDict, unfreeze
8
+
9
+ from transformers.modeling_flax_outputs import FlaxCausalLMOutputWithCrossAttentions
10
+ from transformers.file_utils import add_start_docstrings
11
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
12
+ from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGenerationModule
13
+
14
+ from model.vae import VAE
15
+ from model.outputs import TransformerVaeOutput
16
+ from model.config import T5VaeConfig
17
+
18
+
19
+ @add_start_docstrings("""T5 Model with a `language modeling` head on top converted into a VAE.""")
20
+ class FlaxT5VaeForAutoencodingModule(nn.Module):
21
+ config: T5VaeConfig
22
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
23
+
24
+ def _get_encoder_module(self):
25
+ return self.t5.encoder
26
+
27
+ def _get_vae_encoder_module(self):
28
+ return self.vae.encoder
29
+
30
+ def _get_vae_decoder_module(self):
31
+ return self.vae.decoder
32
+
33
+ def _get_decoder_module(self):
34
+ return self.t5.decoder
35
+
36
+ def setup(self):
37
+ self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5)
38
+ self.vae = VAE(self.config)
39
+
40
+ def __call__(
41
+ self,
42
+ input_ids=None,
43
+ attention_mask=None,
44
+ decoder_input_ids=None,
45
+ decoder_attention_mask=None,
46
+ encoder_outputs=None,
47
+ latent_codes=None,
48
+ output_attentions=None,
49
+ output_hidden_states=None,
50
+ return_dict=None,
51
+ deterministic: bool = True,
52
+ ):
53
+ """
54
+ Adapted from `FlaxT5ForConditionalGenerationModule`
55
+ """
56
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
57
+
58
+ # Encode
59
+ encoder_outputs = self.t5.encoder(
60
+ input_ids=input_ids,
61
+ attention_mask=attention_mask,
62
+ output_attentions=output_attentions,
63
+ output_hidden_states=output_hidden_states,
64
+ return_dict=return_dict,
65
+ deterministic=deterministic,
66
+ )
67
+
68
+ hidden_states = encoder_outputs[0]
69
+
70
+ # Autoencode
71
+ hidden_states, latent_codes = self.vae(hidden_states, latent_codes)
72
+ encoder_attention_mask = jnp.ones((hidden_states.shape[0], hidden_states.shape[1]))
73
+
74
+ # Decode
75
+ decoder_outputs = self.t5.decoder(
76
+ input_ids=decoder_input_ids,
77
+ attention_mask=decoder_attention_mask,
78
+ encoder_hidden_states=hidden_states,
79
+ encoder_attention_mask=encoder_attention_mask,
80
+ output_attentions=output_attentions,
81
+ output_hidden_states=output_hidden_states,
82
+ return_dict=return_dict,
83
+ deterministic=deterministic,
84
+ )
85
+
86
+ sequence_output = decoder_outputs[0]
87
+
88
+ if self.t5.config.tie_word_embeddings:
89
+ # Rescale output before projecting on vocab
90
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
91
+ sequence_output = sequence_output * (self.t5.config.d_model ** -0.5)
92
+
93
+ if self.t5.config.tie_word_embeddings:
94
+ shared_embedding = self.t5.shared.variables["params"]["embedding"]
95
+ lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
96
+ else:
97
+ lm_logits = self.t5.lm_head(sequence_output)
98
+
99
+ if not return_dict:
100
+ return [lm_logits, latent_codes] + decoder_outputs[1:] + encoder_outputs
101
+
102
+ return TransformerVaeOutput(
103
+ logits=lm_logits,
104
+ latent_codes=latent_codes,
105
+ last_hidden_state=decoder_outputs.last_hidden_state,
106
+ past_key_values=decoder_outputs.past_key_values,
107
+ decoder_hidden_states=decoder_outputs.hidden_states,
108
+ decoder_attentions=decoder_outputs.attentions,
109
+ cross_attentions=decoder_outputs.cross_attentions,
110
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
111
+ encoder_hidden_states=encoder_outputs.hidden_states,
112
+ encoder_attentions=encoder_outputs.attentions,
113
+ )
114
+
115
+
116
+ class FlaxT5VaePreTrainedModel(FlaxPreTrainedModel):
117
+ """
118
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
119
+ models.
120
+ """
121
+
122
+ config_class = T5VaeConfig
123
+ base_model_prefix = "transformer"
124
+ module_class: nn.Module = None
125
+
126
+ def __init__(
127
+ self,
128
+ config: T5VaeConfig,
129
+ input_shape: Tuple[int] = (1, 1),
130
+ seed: int = 0,
131
+ dtype: jnp.dtype = jnp.float32,
132
+ **kwargs
133
+ ):
134
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
135
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
136
+
137
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
138
+ # init input tensors
139
+ input_ids = jnp.zeros(input_shape, dtype="i4")
140
+
141
+ attention_mask = jnp.ones_like(input_ids)
142
+ decoder_input_ids = jnp.ones_like(input_ids)
143
+ decoder_attention_mask = jnp.ones_like(input_ids)
144
+
145
+ params_rng, dropout_rng = jax.random.split(rng)
146
+ rngs = {"params": params_rng, "dropout": dropout_rng}
147
+
148
+ return self.module.init(
149
+ rngs,
150
+ input_ids,
151
+ attention_mask,
152
+ decoder_input_ids,
153
+ decoder_attention_mask,
154
+ )["params"]
155
+
156
+ def __call__(
157
+ self,
158
+ input_ids: jnp.ndarray,
159
+ attention_mask: Optional[jnp.ndarray] = None,
160
+ decoder_input_ids: jnp.ndarray = None,
161
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
162
+ output_attentions: Optional[bool] = None,
163
+ output_hidden_states: Optional[bool] = None,
164
+ return_dict: Optional[bool] = None,
165
+ train: bool = False,
166
+ params: dict = None,
167
+ dropout_rng: PRNGKey = None,
168
+ ):
169
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
170
+ output_hidden_states = (
171
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
172
+ )
173
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
174
+
175
+ if decoder_input_ids is None:
176
+ raise ValueError(
177
+ "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
178
+ )
179
+
180
+ # prepare encoder inputs
181
+ if attention_mask is None:
182
+ attention_mask = jnp.ones_like(input_ids)
183
+
184
+ # prepare decoder inputs
185
+ if decoder_attention_mask is None:
186
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
187
+
188
+ # Handle any PRNG if needed
189
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
190
+
191
+ return self.module.apply(
192
+ {"params": params or self.params},
193
+ input_ids=jnp.array(input_ids, dtype="i4"),
194
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
195
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
196
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
197
+ output_attentions=output_attentions,
198
+ output_hidden_states=output_hidden_states,
199
+ return_dict=return_dict,
200
+ deterministic=not train,
201
+ rngs=rngs,
202
+ )
203
+
204
+ def init_cache(self, batch_size, max_length, latent_codes):
205
+ r"""
206
+ Args:
207
+ batch_size (:obj:`int`):
208
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
209
+ max_length (:obj:`int`):
210
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
211
+ cache.
212
+ latent_codes (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
213
+ ``latent_codes`` consists of compressed hidden-states at the output of the last layer of the encoder.
214
+ Used in the cross-attention of the decoder.
215
+ """
216
+ # init input variables to retrieve cache
217
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
218
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
219
+
220
+ def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
221
+ vae_decoder_module = module._get_vae_decoder_module()
222
+ decoder_module = module._get_decoder_module()
223
+ return decoder_module(
224
+ decoder_input_ids,
225
+ decoder_attention_mask,
226
+ encoder_hidden_states=vae_decoder_module(latent_codes),
227
+ **kwargs,
228
+ )
229
+
230
+ init_variables = self.module.init(
231
+ jax.random.PRNGKey(0),
232
+ decoder_input_ids=decoder_input_ids,
233
+ latent_codes=latent_codes,
234
+ decoder_attention_mask=decoder_attention_mask,
235
+ init_cache=True,
236
+ method=_decoder_forward, # we only need to call the decoder to init the cache
237
+ )
238
+ return unfreeze(init_variables["cache"])
239
+
240
+ def encode(
241
+ self,
242
+ input_ids: jnp.ndarray,
243
+ attention_mask: Optional[jnp.ndarray] = None,
244
+ output_attentions: Optional[bool] = None,
245
+ output_hidden_states: Optional[bool] = None,
246
+ return_dict: Optional[bool] = None,
247
+ train: bool = False,
248
+ params: dict = None,
249
+ dropout_rng: PRNGKey = None,
250
+ ):
251
+ raise NotImplementedError()
252
+
253
+ def decode(
254
+ self,
255
+ decoder_input_ids,
256
+ latent_codes,
257
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
258
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
259
+ past_key_values: dict = None,
260
+ output_attentions: Optional[bool] = None,
261
+ output_hidden_states: Optional[bool] = None,
262
+ return_dict: Optional[bool] = None,
263
+ train: bool = False,
264
+ params: dict = None,
265
+ dropout_rng: PRNGKey = None,
266
+ ):
267
+ raise NotImplementedError()
268
+
269
+
270
+ class FlaxT5VaeForAutoencoding(FlaxT5VaePreTrainedModel):
271
+ module_class = FlaxT5VaeForAutoencodingModule
272
+
273
+ def __call__(
274
+ self,
275
+ input_ids: jnp.ndarray,
276
+ attention_mask: Optional[jnp.ndarray] = None,
277
+ decoder_input_ids=None,
278
+ decoder_attention_mask=None,
279
+ output_attentions: Optional[bool] = None,
280
+ output_hidden_states: Optional[bool] = None,
281
+ return_dict: Optional[bool] = None,
282
+ train: bool = False,
283
+ params: dict = None,
284
+ dropout_rng: PRNGKey = None,
285
+ ):
286
+ '''
287
+ Adapted from `FlaxT5PreTrainedModel`
288
+ '''
289
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
290
+ output_hidden_states = (
291
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
292
+ )
293
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
294
+
295
+ if decoder_input_ids is None:
296
+ raise ValueError(
297
+ "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
298
+ )
299
+
300
+ # prepare encoder inputs
301
+ if attention_mask is None:
302
+ attention_mask = jnp.ones_like(input_ids)
303
+
304
+ # prepare decoder inputs
305
+ if decoder_attention_mask is None:
306
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
307
+
308
+ # Handle any PRNG if needed
309
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
310
+
311
+ return self.module.apply(
312
+ {"params": params or self.params},
313
+ input_ids=jnp.array(input_ids, dtype="i4"),
314
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
315
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
316
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
317
+ output_attentions=output_attentions,
318
+ output_hidden_states=output_hidden_states,
319
+ return_dict=return_dict,
320
+ deterministic=not train,
321
+ rngs=rngs,
322
+ )
323
+
324
+ def encode(
325
+ self,
326
+ input_ids: jnp.ndarray,
327
+ attention_mask: Optional[jnp.ndarray] = None,
328
+ output_attentions: Optional[bool] = None,
329
+ output_hidden_states: Optional[bool] = None,
330
+ return_dict: Optional[bool] = None,
331
+ train: bool = False,
332
+ params: dict = None,
333
+ dropout_rng: PRNGKey = None,
334
+ ):
335
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
336
+ output_hidden_states = (
337
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
338
+ )
339
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
340
+
341
+ if attention_mask is None:
342
+ attention_mask = jnp.ones_like(input_ids)
343
+
344
+ # Handle any PRNG if needed
345
+ rngs = {}
346
+ if dropout_rng is not None:
347
+ rngs["dropout"] = dropout_rng
348
+
349
+ def _encoder_forward(module, input_ids, attention_mask, **kwargs):
350
+ encode_module = module._get_encoder_module()
351
+ vae_encoder_module = module._get_vae_encoder_module()
352
+ return vae_encoder_module(encode_module(input_ids, attention_mask, **kwargs)[0])
353
+
354
+ return self.module.apply(
355
+ {"params": params or self.params},
356
+ input_ids=jnp.array(input_ids, dtype="i4"),
357
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
358
+ output_attentions=output_attentions,
359
+ output_hidden_states=output_hidden_states,
360
+ return_dict=return_dict,
361
+ deterministic=not train,
362
+ rngs=rngs,
363
+ method=_encoder_forward,
364
+ )
365
+
366
+ def decode(
367
+ self,
368
+ decoder_input_ids,
369
+ latent_codes,
370
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
371
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
372
+ past_key_values: dict = None,
373
+ output_attentions: Optional[bool] = None,
374
+ output_hidden_states: Optional[bool] = None,
375
+ return_dict: Optional[bool] = None,
376
+ train: bool = False,
377
+ params: dict = None,
378
+ dropout_rng: PRNGKey = None,
379
+ ):
380
+ r"""
381
+ Returns:
382
+
383
+ Example::
384
+
385
+ >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-small')
386
+ >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
387
+
388
+ >>> text = "My friends are cool but they eat too many carbs."
389
+ >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
390
+ >>> latent_codes = model.encode(**inputs)
391
+
392
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
393
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
394
+
395
+ >>> outputs = model.decode(decoder_input_ids, latent_codes)
396
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
397
+ """
398
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
399
+ output_hidden_states = (
400
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
401
+ )
402
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
403
+
404
+ if encoder_attention_mask is None:
405
+ batch_size, sequence_length = latent_codes.shape[:2]
406
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
407
+
408
+ batch_size, sequence_length = decoder_input_ids.shape
409
+ if decoder_attention_mask is None:
410
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
411
+
412
+ # Handle any PRNG if needed
413
+ rngs = {}
414
+ if dropout_rng is not None:
415
+ rngs["dropout"] = dropout_rng
416
+
417
+ inputs = {"params": params or self.params}
418
+
419
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
420
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
421
+ # it can be changed by FlaxT5Attention module
422
+ if past_key_values:
423
+ inputs["cache"] = past_key_values
424
+ mutable = ["cache"]
425
+ else:
426
+ mutable = False
427
+
428
+ def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
429
+ vae_decoder_module = module._get_vae_decoder_module()
430
+ decoder_module = module._get_decoder_module()
431
+ decoder_outputs = decoder_module(
432
+ decoder_input_ids,
433
+ decoder_attention_mask,
434
+ encoder_hidden_states=vae_decoder_module(latent_codes),
435
+ **kwargs,
436
+ )
437
+ sequence_output = decoder_outputs[0]
438
+
439
+ if self.config.tie_word_embeddings:
440
+ # Rescale output before projecting on vocab
441
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
442
+ sequence_output = sequence_output * (self.config.d_model ** -0.5)
443
+
444
+ if self.config.tie_word_embeddings:
445
+ shared_embedding = module.t5.shared.variables["params"]["embedding"]
446
+ lm_logits = module.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
447
+ else:
448
+ lm_logits = module.t5.lm_head(sequence_output)
449
+
450
+ return lm_logits, decoder_outputs
451
+
452
+ outputs = self.module.apply(
453
+ inputs,
454
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
455
+ latent_codes=latent_codes,
456
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
457
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
458
+ output_attentions=output_attentions,
459
+ output_hidden_states=output_hidden_states,
460
+ return_dict=return_dict,
461
+ deterministic=not train,
462
+ rngs=rngs,
463
+ mutable=mutable,
464
+ method=_decoder_forward,
465
+ )
466
+
467
+ if past_key_values is None:
468
+ lm_logits, decoder_outputs = outputs
469
+ else:
470
+ (lm_logits, decoder_outputs), past = outputs
471
+
472
+ if return_dict:
473
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
474
+ logits=lm_logits,
475
+ hidden_states=decoder_outputs.hidden_states,
476
+ attentions=decoder_outputs.attentions,
477
+ cross_attentions=decoder_outputs.cross_attentions,
478
+ )
479
+ else:
480
+ outputs = (lm_logits,) + decoder_outputs[1:]
481
+
482
+ # add updated cache to model output
483
+ if past_key_values is not None and return_dict:
484
+ outputs["past_key_values"] = unfreeze(past["cache"])
485
+ return outputs
486
+ elif past_key_values is not None and not return_dict:
487
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
488
+
489
+ return outputs
490
+
491
+ def prepare_inputs_for_generation(
492
+ self,
493
+ decoder_input_ids,
494
+ max_length,
495
+ attention_mask: Optional[jnp.DeviceArray] = None,
496
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
497
+ latent_codes=None,
498
+ **kwargs
499
+ ):
500
+ # initializing the cache
501
+ batch_size, seq_length = decoder_input_ids.shape
502
+
503
+ past_key_values = self.init_cache(batch_size, max_length, latent_codes)
504
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
505
+ # But since the decoder uses a causal mask, those positions are masked anyways.
506
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
507
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
508
+ if decoder_attention_mask is not None:
509
+ extended_attention_mask = jax.lax.dynamic_update_slice(
510
+ extended_attention_mask, decoder_attention_mask, (0, 0)
511
+ )
512
+
513
+ return {
514
+ "past_key_values": past_key_values,
515
+ "latent_codes": latent_codes,
516
+ "encoder_attention_mask": attention_mask,
517
+ "decoder_attention_mask": extended_attention_mask,
518
+ }
519
+
520
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
521
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
522
+ return model_kwargs
model/utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+
3
+ import flax.linen as nn
4
+
5
+
6
+ class MLP(nn.Module):
7
+ features: Sequence[int]
8
+
9
+ @nn.compact
10
+ def __call__(self, x):
11
+ for feat in self.features[:-1]:
12
+ x = nn.relu(nn.Dense(feat)(x))
13
+ x = nn.Dense(self.features[-1])(x)
14
+ return x
15
+
16
+
17
+ def assertEqual(actual, expected, msg, first="Got", second="Expected"):
18
+ if actual != expected:
19
+ raise ValueError(msg + f' {first}: "{actual}" {second}: "{expected}"')
20
+
21
+
22
+ def assertIn(actual, expected, msg, first="Got", second="Expected one of"):
23
+ if actual not in expected:
24
+ raise ValueError(msg + f' {first}: "{actual}" {second}: {expected}')
model/vae.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax.numpy as jnp
2
+ import flax.linen as nn
3
+
4
+ from model.encoders import VAE_ENCODER_MODELS
5
+ from model.decoders import VAE_DECODER_MODELS
6
+ from model.config import T5VaeConfig
7
+
8
+
9
+ class VAE(nn.Module):
10
+ # see https://github.com/google/flax#what-does-flax-look-like
11
+ """
12
+ An MMD-VAE used with encoder-decoder models.
13
+ Encodes all token encodings into a single latent & spits them back out.
14
+ """
15
+ config: T5VaeConfig
16
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
17
+
18
+ def setup(self):
19
+ self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_token_size, self.config.n_latent_tokens)
20
+ self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
21
+
22
+ def __call__(self, encoding=None, latent_codes=None):
23
+ latent_codes = self.encode(encoding)
24
+ return self.decode(latent_codes), latent_codes
25
+
26
+ def encode(self, encoding):
27
+ return self.encoder(encoding)
28
+
29
+ def decode(self, latent):
30
+ return self.decoder(latent)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ jax
2
+ jaxlib
3
+ -r requirements-tpu.txt
train.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
3
+
4
+ TODO:
5
+ - [x] Get this running.
6
+ - [x] Don't make decoder input ids.
7
+ - [ ] Add reg loss
8
+ - [x] calculate MMD loss
9
+ - [ ] schedule MMD loss weight
10
+ - [ ] Add these params to the training arguments.
11
+
12
+ reg_schedule_k (:obj:`float`, `optional`, defaults to 0.0025):
13
+ Multiplied by global_step in a sigmoid, more gradually increase regulariser loss weight.
14
+ reg_schedule_b (:obj:`float`, `optional`, defaults to 6.25):
15
+ Added to global step in sigmoid, further delays increase in regulariser loss weight.
16
+ use_extra_logs (:obj:`bool`, `optional`, defaults to False):
17
+ Store extra logs during each training inference.
18
+
19
+ - [ ] Send the schedule time to the compute_loss method and calculate a coefficient based on that.
20
+ '''
21
+ import logging
22
+ import math
23
+ import os
24
+ import sys
25
+ import time
26
+ from dataclasses import dataclass, field
27
+ from pathlib import Path
28
+ from typing import Callable, Optional
29
+
30
+ import datasets
31
+ from datasets import Dataset, load_dataset
32
+ from tqdm import tqdm
33
+
34
+ import jax
35
+ import jax.numpy as jnp
36
+ import optax
37
+ import transformers
38
+ from flax import jax_utils, traverse_util
39
+ from flax.jax_utils import unreplicate
40
+ from flax.training import train_state
41
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
42
+ from transformers import (
43
+ AutoTokenizer,
44
+ HfArgumentParser,
45
+ TrainingArguments,
46
+ is_tensorboard_available,
47
+ )
48
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
49
+ from transformers.testing_utils import CaptureLogger
50
+
51
+ from model.t5_vae import FlaxT5VaeForAutoencoding
52
+ from model.config import T5VaeConfig
53
+
54
+
55
+ logger = logging.getLogger(__name__)
56
+
57
+
58
+ @dataclass
59
+ class ModelArguments:
60
+ """
61
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
62
+ """
63
+
64
+ model_name_or_path: Optional[str] = field(
65
+ default=None,
66
+ metadata={
67
+ "help": "The model checkpoint for weights initialization."
68
+ "Don't set if you want to train a model from scratch."
69
+ },
70
+ )
71
+ t5_model_name_or_path: Optional[str] = field(
72
+ default=None,
73
+ metadata={
74
+ "help": "The T5 model checkpoint for weights initialization."
75
+ "Needed when not starting from a T5-VAE model."
76
+ },
77
+ )
78
+ n_latent_tokens: Optional[int] = field(
79
+ default=6,
80
+ metadata={
81
+ "help": "Number of latent tokens (must be less than seq length)."
82
+ },
83
+ )
84
+ latent_token_size: Optional[int] = field(
85
+ default=32,
86
+ metadata={
87
+ "help": "Number of dimensions to use for each latent token."
88
+ },
89
+ )
90
+ config_path: Optional[str] = field(
91
+ default=None, metadata={"help": "Pretrained config path"}
92
+ )
93
+ tokenizer_name: Optional[str] = field(
94
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
95
+ )
96
+ cache_dir: Optional[str] = field(
97
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
98
+ )
99
+ use_fast_tokenizer: bool = field(
100
+ default=True,
101
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
102
+ )
103
+ dtype: Optional[str] = field(
104
+ default="float32",
105
+ metadata={
106
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
107
+ },
108
+ )
109
+
110
+
111
+ @dataclass
112
+ class DataTrainingArguments:
113
+ """
114
+ Arguments pertaining to what data we are going to input our model for training and eval.
115
+ """
116
+
117
+ dataset_name: Optional[str] = field(
118
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
119
+ )
120
+ dataset_config_name: Optional[str] = field(
121
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
122
+ )
123
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
124
+ validation_file: Optional[str] = field(
125
+ default=None,
126
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
127
+ )
128
+ max_train_samples: Optional[int] = field(
129
+ default=None,
130
+ metadata={
131
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
132
+ "value if set."
133
+ },
134
+ )
135
+ max_eval_samples: Optional[int] = field(
136
+ default=None,
137
+ metadata={
138
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
139
+ "value if set."
140
+ },
141
+ )
142
+ overwrite_cache: bool = field(
143
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
144
+ )
145
+ validation_split_percentage: Optional[int] = field(
146
+ default=5,
147
+ metadata={
148
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
149
+ },
150
+ )
151
+ block_size: Optional[int] = field(
152
+ default=None,
153
+ metadata={
154
+ "help": "Optional input sequence length after tokenization. "
155
+ "The training dataset will be truncated in block of this size for training. "
156
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
157
+ },
158
+ )
159
+ streaming: bool = field(
160
+ default=False, metadata={"help": "Stream the dataset."}
161
+ )
162
+ overwrite_cache: bool = field(
163
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
164
+ )
165
+ preprocessing_num_workers: Optional[int] = field(
166
+ default=None,
167
+ metadata={"help": "The number of processes to use for the preprocessing."},
168
+ )
169
+
170
+ def __post_init__(self):
171
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
172
+ raise ValueError("Need either a dataset name or a training/validation file.")
173
+ else:
174
+ if self.train_file is not None:
175
+ extension = self.train_file.split(".")[-1]
176
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
177
+ if self.validation_file is not None:
178
+ extension = self.validation_file.split(".")[-1]
179
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
180
+
181
+
182
+ class TrainState(train_state.TrainState):
183
+ dropout_rng: jnp.ndarray
184
+
185
+ def replicate(self):
186
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
187
+
188
+
189
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
190
+ """
191
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
192
+ Shuffle batches if `shuffle` is `True`.
193
+ """
194
+ steps_per_epoch = len(dataset) // batch_size
195
+
196
+ if shuffle:
197
+ batch_idx = jax.random.permutation(rng, len(dataset))
198
+ else:
199
+ batch_idx = jnp.arange(len(dataset))
200
+
201
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
202
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
203
+
204
+ for idx in batch_idx:
205
+ batch = dataset[idx]
206
+ batch = {k: jnp.array(v) for k, v in batch.items()}
207
+
208
+ batch = shard(batch)
209
+
210
+ yield batch
211
+
212
+
213
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
214
+ summary_writer.scalar("train_time", train_time, step)
215
+
216
+ train_metrics = get_metrics(train_metrics)
217
+ for key, vals in train_metrics.items():
218
+ tag = f"train_{key}"
219
+ for i, val in enumerate(vals):
220
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
221
+
222
+
223
+ def write_eval_metric(summary_writer, eval_metrics, step):
224
+ for metric_name, value in eval_metrics.items():
225
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
226
+
227
+
228
+ def create_learning_rate_fn(
229
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
230
+ ) -> Callable[[int], jnp.array]:
231
+ """Returns a linear warmup, linear_decay learning rate function."""
232
+ steps_per_epoch = train_ds_size // train_batch_size
233
+ num_train_steps = steps_per_epoch * num_train_epochs
234
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
235
+ decay_fn = optax.linear_schedule(
236
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
237
+ )
238
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
239
+ return schedule_fn
240
+
241
+
242
+ def main():
243
+ # See all possible arguments in src/transformers/training_args.py
244
+ # or by passing the --help flag to this script.
245
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
246
+
247
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
248
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
249
+ # If we pass only one argument to the script and it's the path to a json file,
250
+ # let's parse it to get our arguments.
251
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
252
+ else:
253
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
254
+
255
+ if (
256
+ os.path.exists(training_args.output_dir)
257
+ and os.listdir(training_args.output_dir)
258
+ and training_args.do_train
259
+ and not training_args.overwrite_output_dir
260
+ ):
261
+ raise ValueError(
262
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
263
+ "Use --overwrite_output_dir to overcome."
264
+ )
265
+
266
+ if data_args.block_size is None:
267
+ raise Exception('Must set block_size so we know what length of sequence to autoencode.')
268
+
269
+ # Make one log on every process with the configuration for debugging.
270
+ logging.basicConfig(
271
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
272
+ datefmt="%m/%d/%Y %H:%M:%S",
273
+ level=logging.INFO,
274
+ )
275
+ # Setup logging, we only want one process per machine to log things on the screen.
276
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
277
+ if jax.process_index() == 0:
278
+ datasets.utils.logging.set_verbosity_warning()
279
+ transformers.utils.logging.set_verbosity_info()
280
+ else:
281
+ datasets.utils.logging.set_verbosity_error()
282
+ transformers.utils.logging.set_verbosity_error()
283
+
284
+ # Set the verbosity to info of the Transformers logger (on main process only):
285
+ logger.info(f"Training/evaluation parameters {training_args}")
286
+
287
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
288
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
289
+ # (the dataset will be downloaded automatically from the datasets Hub).
290
+ #
291
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
292
+ # 'text' is found. You can easily tweak this behavior (see below).
293
+ #
294
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
295
+ # download the dataset.
296
+ if data_args.dataset_name is not None:
297
+ # Downloading and loading a dataset from the hub.
298
+ dataset = load_dataset(
299
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.streaming, keep_in_memory=False
300
+ )
301
+
302
+ if "validation" not in dataset.keys():
303
+ dataset["validation"] = load_dataset(
304
+ data_args.dataset_name,
305
+ data_args.dataset_config_name,
306
+ split=f"train[:{data_args.validation_split_percentage}%]",
307
+ cache_dir=model_args.cache_dir,
308
+ )
309
+ dataset["train"] = load_dataset(
310
+ data_args.dataset_name,
311
+ data_args.dataset_config_name,
312
+ split=f"train[{data_args.validation_split_percentage}%:]",
313
+ cache_dir=model_args.cache_dir,
314
+ )
315
+ else:
316
+ data_files = {}
317
+ if data_args.train_file is not None:
318
+ data_files["train"] = data_args.train_file
319
+ if data_args.validation_file is not None:
320
+ data_files["validation"] = data_args.validation_file
321
+ extension = data_args.train_file.split(".")[-1]
322
+ if extension == "txt":
323
+ extension = "text"
324
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
325
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
326
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
327
+
328
+ # Load pretrained model and tokenizer
329
+
330
+ # Distributed training:
331
+ # The .from_pretrained methods guarantee that only one local process can concurrently
332
+ # download model & vocab.
333
+
334
+ if model_args.config_path:
335
+ config = T5VaeConfig.from_pretrained(
336
+ model_args.config_path, cache_dir=model_args.cache_dir
337
+ )
338
+ elif model_args.model_name_or_path:
339
+ config = T5VaeConfig.from_pretrained(
340
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir
341
+ )
342
+ else:
343
+ config = T5VaeConfig(**model_args.__dict__)
344
+ logger.warning("You are instantiating a new config instance from scratch.")
345
+
346
+ if model_args.tokenizer_name:
347
+ tokenizer = AutoTokenizer.from_pretrained(
348
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
349
+ )
350
+ elif model_args.t5_model_name_or_path:
351
+ tokenizer = AutoTokenizer.from_pretrained(
352
+ model_args.t5_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
353
+ )
354
+ else:
355
+ raise ValueError(
356
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
357
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
358
+ )
359
+
360
+ if model_args.model_name_or_path:
361
+ model = FlaxT5VaeForAutoencoding.from_pretrained(
362
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
363
+ )
364
+ # TODO assert token embedding size == len(tokenizer)
365
+ assert(model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size.")
366
+ else:
367
+ vocab_size = len(tokenizer)
368
+ config.t5.vocab_size = vocab_size
369
+ config.vocab_size = vocab_size
370
+ logger.info("Training new model from scratch.")
371
+ model = FlaxT5VaeForAutoencoding(
372
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
373
+ )
374
+
375
+ if model_args.add_special_tokens:
376
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
377
+ num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
378
+ print('We have added', num_added_tokens, 'tokens to GPT2')
379
+ model.resize_token_embeddings(len(tokenizer))
380
+ assert tokenizer.pad_token == '<PAD>'
381
+
382
+ # Preprocessing the datasets.
383
+ # First we tokenize all the texts.
384
+ if training_args.do_train:
385
+ column_names = dataset["train"].column_names
386
+ else:
387
+ column_names = dataset["validation"].column_names
388
+ text_column_name = "text" if "text" in column_names else column_names[0]
389
+
390
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
391
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
392
+
393
+ def tokenize_function(examples):
394
+ with CaptureLogger(tok_logger) as cl:
395
+ output = tokenizer(examples[text_column_name])
396
+ # clm input could be much much longer than block_size
397
+ if "Token indices sequence length is longer than the" in cl.out:
398
+ tok_logger.warning(
399
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
400
+ )
401
+ return output
402
+
403
+ # remove dataset tasks
404
+ for k in dataset.keys():
405
+ dataset[k].info.task_templates = []
406
+
407
+ tokenized_datasets = dataset.map(
408
+ tokenize_function,
409
+ batched=True,
410
+ num_proc=data_args.preprocessing_num_workers,
411
+ remove_columns=column_names,
412
+ load_from_cache_file=not data_args.overwrite_cache,
413
+ )
414
+
415
+ if data_args.block_size > tokenizer.model_max_length:
416
+ logger.warning(
417
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
418
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
419
+ )
420
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
421
+
422
+ pad_token_id, start_token_id = tokenizer.pad_token_id, config.decoder_start_token_id
423
+
424
+ def clip_texts(examples):
425
+ examples["labels"] = examples["input_ids"].copy()
426
+
427
+ for i, input_ids in enumerate(examples["input_ids"]):
428
+ if len(input_ids) > block_size:
429
+ for k in examples.keys():
430
+ examples[k][i] = examples[k][i][:block_size]
431
+ elif len(input_ids) < block_size:
432
+ delta = block_size - len(input_ids)
433
+ examples['input_ids'][i] = examples['input_ids'][i] + [pad_token_id] * delta
434
+ examples['attention_mask'][i] = examples['attention_mask'][i] + [0] * delta
435
+ examples['labels'][i] = examples['labels'][i] + [-100] * delta
436
+
437
+ return examples
438
+
439
+ logger.info('clip_texts...')
440
+ clipped_lm_datasets = tokenized_datasets.map(
441
+ clip_texts,
442
+ batched=True,
443
+ num_proc=data_args.preprocessing_num_workers,
444
+ load_from_cache_file=not data_args.overwrite_cache,
445
+ )
446
+
447
+ def add_decoder_input_ids(examples):
448
+ arr_input_ids = jnp.array(examples["input_ids"])
449
+ pad = pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
450
+ arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
451
+ examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, pad_token_id, start_token_id)
452
+
453
+ arr_attention_mask = jnp.array(examples['attention_mask'])
454
+ ones = jnp.ones((arr_attention_mask.shape[0], 1), dtype=jnp.int32)
455
+ examples['decoder_attention_mask'] = jnp.concatenate((ones, arr_attention_mask), axis=1)
456
+
457
+ for k in ['decoder_input_ids', 'decoder_attention_mask']:
458
+ examples[k] = examples[k].tolist()
459
+
460
+ return examples
461
+
462
+ logger.info('add_decoder_input_ids...')
463
+ lm_datasets = clipped_lm_datasets.map(
464
+ add_decoder_input_ids,
465
+ batched=True,
466
+ num_proc=data_args.preprocessing_num_workers,
467
+ load_from_cache_file=not data_args.overwrite_cache,
468
+ )
469
+
470
+ if training_args.do_train:
471
+ if "train" not in tokenized_datasets:
472
+ raise ValueError("--do_train requires a train dataset")
473
+ train_dataset = lm_datasets["train"]
474
+ if data_args.max_train_samples is not None:
475
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
476
+
477
+ if training_args.do_eval:
478
+ if "validation" not in tokenized_datasets:
479
+ raise ValueError("--do_eval requires a validation dataset")
480
+ eval_dataset = lm_datasets["validation"]
481
+ if data_args.max_eval_samples is not None:
482
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
483
+
484
+ # Enable tensorboard only on the master node
485
+ has_tensorboard = is_tensorboard_available()
486
+ if has_tensorboard and jax.process_index() == 0:
487
+ try:
488
+ from flax.metrics.tensorboard import SummaryWriter
489
+
490
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
491
+ except ImportError as ie:
492
+ has_tensorboard = False
493
+ logger.warning(
494
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
495
+ )
496
+ else:
497
+ logger.warning(
498
+ "Unable to display metrics through TensorBoard because the package is not installed: "
499
+ "Please run pip install tensorboard to enable."
500
+ )
501
+
502
+ # Initialize our training
503
+ rng = jax.random.PRNGKey(training_args.seed)
504
+ rng, dropout_rng = jax.random.split(rng)
505
+
506
+ # Store some constant
507
+ num_epochs = int(training_args.num_train_epochs)
508
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
509
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
510
+ steps_per_epoch = len(train_dataset) // train_batch_size
511
+ total_train_steps = steps_per_epoch * num_epochs
512
+
513
+ # Create learning rate schedule
514
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
515
+ len(train_dataset),
516
+ train_batch_size,
517
+ training_args.num_train_epochs,
518
+ training_args.warmup_steps,
519
+ training_args.learning_rate,
520
+ )
521
+
522
+ # We use Optax's "masking" functionality to not apply weight decay
523
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
524
+ # mask boolean with the same structure as the parameters.
525
+ # The mask is True for parameters that should be decayed.
526
+ # Note that this mask is specifically adapted for FlaxGPT2.
527
+ # For other models, one should correct the layer norm parameter naming
528
+ # accordingly.
529
+ def decay_mask_fn(params):
530
+ flat_params = traverse_util.flatten_dict(params)
531
+ flat_mask = {
532
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
533
+ for path in flat_params
534
+ }
535
+ return traverse_util.unflatten_dict(flat_mask)
536
+
537
+ # create adam optimizer
538
+ if training_args.adafactor:
539
+ # We use the default parameters here to initialize adafactor,
540
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
541
+ optimizer = optax.adafactor(
542
+ learning_rate=linear_decay_lr_schedule_fn,
543
+ )
544
+ else:
545
+ optimizer = optax.adamw(
546
+ learning_rate=linear_decay_lr_schedule_fn,
547
+ b1=training_args.adam_beta1,
548
+ b2=training_args.adam_beta2,
549
+ eps=training_args.adam_epsilon,
550
+ weight_decay=training_args.weight_decay,
551
+ mask=decay_mask_fn,
552
+ )
553
+
554
+ # Setup train state
555
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
556
+
557
+ def compute_kernel(x, y):
558
+ x_size = x.shape[0]
559
+ y_size = y.shape[0]
560
+ dim = x.shape[1]
561
+ tiled_x = jnp.repeat(jnp.reshape(x, (x_size, 1, dim)), y_size, axis=1)
562
+ tiled_y = jnp.repeat(jnp.reshape(y, (1, y_size, dim)), x_size, axis=0)
563
+ return jnp.exp(-jnp.mean((tiled_x - tiled_y) ** 2, axis=2) / dim * 1.0)
564
+
565
+ def compute_mmd(x, y):
566
+ x_kernel = compute_kernel(x, x)
567
+ y_kernel = compute_kernel(y, y)
568
+ xy_kernel = compute_kernel(x, y)
569
+ return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)
570
+
571
+ def regulariser_loss(latent_codes, rng):
572
+ true_samples = jax.random.normal(rng, latent_codes.shape)
573
+ # return jax.vmap(compute_mmd)(true_samples, latent_codes)
574
+ return compute_mmd(true_samples, latent_codes)
575
+
576
+ def loss_fn(logits, labels, latent_codes, regulariser_rng):
577
+ shift_logits = logits[..., :-1, :]
578
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(labels, logits.shape[-1]))
579
+ reg_loss = regulariser_loss(latent_codes.reshape(-1, latent_codes.shape[-1]), regulariser_rng)
580
+ return loss.mean() + reg_loss.mean()
581
+
582
+ # Define gradient update step fn
583
+ def train_step(state, batch):
584
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
585
+ new_dropout_rng, regulariser_rng = jax.random.split(new_dropout_rng)
586
+
587
+ def compute_loss(params):
588
+ labels = batch.pop("labels")
589
+ outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
590
+ loss = loss_fn(outputs[0], labels, outputs[1], regulariser_rng)
591
+ return loss
592
+
593
+ grad_fn = jax.value_and_grad(compute_loss)
594
+ loss, grad = grad_fn(state.params)
595
+ grad = jax.lax.pmean(grad, "batch")
596
+
597
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
598
+
599
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
600
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
601
+
602
+ return new_state, metrics
603
+
604
+ # Define eval fn
605
+ def eval_step(params, rng, batch):
606
+ labels = batch.pop("labels")
607
+ logits, latent_codes = model(**batch, params=params, train=False)[:2]
608
+ loss = loss_fn(logits, labels, latent_codes, rng)
609
+
610
+ # summarize metrics
611
+ metrics = {"loss": loss}
612
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
613
+ return metrics
614
+
615
+ # Create parallel version of the train and eval step
616
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
617
+ p_eval_step = jax.pmap(eval_step, "batch")
618
+
619
+ # Replicate the train state on each device
620
+ state = state.replicate()
621
+
622
+ logger.info("***** Running training *****")
623
+ logger.info(f" Num examples = {len(train_dataset)}")
624
+ logger.info(f" Num Epochs = {num_epochs}")
625
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
626
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
627
+ logger.info(f" Total optimization steps = {total_train_steps}")
628
+
629
+ train_time = 0
630
+ train_metrics = []
631
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
632
+ for epoch in epochs:
633
+ # ======================== Training ================================
634
+ train_start = time.time()
635
+
636
+ # Create sampling rng
637
+ rng, input_rng = jax.random.split(rng)
638
+
639
+ # Generate an epoch by shuffling sampling indices from the train dataset
640
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
641
+ steps_per_epoch = len(train_dataset) // train_batch_size
642
+ # train
643
+ for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
644
+ batch = next(train_loader)
645
+ state, train_metric = p_train_step(state, batch)
646
+ train_metrics.append(train_metric)
647
+
648
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
649
+
650
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
651
+ # Save metrics
652
+ train_metric = unreplicate(train_metric)
653
+ train_time += time.time() - train_start
654
+ if has_tensorboard and jax.process_index() == 0:
655
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
656
+
657
+ epochs.write(
658
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
659
+ )
660
+
661
+ train_metrics = []
662
+
663
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
664
+ # ======================== Evaluating ==============================
665
+ eval_metrics = []
666
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
667
+ eval_steps = len(eval_dataset) // eval_batch_size
668
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
669
+ # Model forward
670
+ batch = next(eval_loader)
671
+ metrics = p_eval_step(state.params, state.dropout_rng, batch)
672
+ eval_metrics.append(metrics)
673
+
674
+ # normalize eval metrics
675
+ eval_metrics = get_metrics(eval_metrics)
676
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
677
+
678
+ try:
679
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
680
+ except OverflowError:
681
+ eval_metrics["perplexity"] = float("inf")
682
+
683
+ # Print metrics and update progress bar
684
+ desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
685
+ epochs.write(desc)
686
+ epochs.desc = desc
687
+
688
+ # Save metrics
689
+ if has_tensorboard and jax.process_index() == 0:
690
+ cur_step = epoch * (len(train_dataset) // train_batch_size)
691
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
692
+
693
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
694
+ # save checkpoint after each epoch and push checkpoint to the hub
695
+ if jax.process_index() == 0:
696
+ params = jax.device_get(unreplicate(state.params))
697
+ model.save_pretrained(
698
+ training_args.output_dir,
699
+ params=params,
700
+ push_to_hub=training_args.push_to_hub,
701
+ commit_message=f"Saving weights and logs of step {cur_step}",
702
+ )
703
+
704
+
705
+ if __name__ == "__main__":
706
+ main()
train.sh ADDED
File without changes