# t5-vae-python/model/config.py
import copy

from transformers import AutoConfig, T5Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from model.decoders import VAE_DECODER_MODELS
from model.encoders import VAE_ENCODER_MODELS
from model.utils import assertEqual, assertIn

logger = logging.get_logger(__name__)


class T5VaeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of :class:`FlaxT5VAE`.
It is used to instantiate a T5-VAE model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the T5 `t5-vae-base architecture.
To be able to use `transformer.trainer.Trainer` we need some specific training logic & config in the model.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Arguments:
n_latent_tokens (:obj:`int`, `optional`, defaults to 6):
Number of latent tokens (must be less than seq length).
latent_token_size (:obj:`int`, `optional`, defaults to 32):
Number of dimensions to use for each latent token.
t5_name (:obj:`str`, `optional`, defaults to t5-base):
Name of the Transformer model to use as a decoder.
block_size (:obj:`int`, `optional`, defaults to 60):
NOTE: Every input sequence must be padded to be equal to this length.
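
    Example (an illustrative sketch; it assumes the default empty-string encoder/decoder names are valid
    entries in ``VAE_ENCODER_MODELS`` and ``VAE_DECODER_MODELS``)::

        from model.config import T5VaeConfig

        # Build a config from explicit T5 hyperparameters (no pretrained checkpoint is downloaded).
        config = T5VaeConfig(n_latent_tokens=6, latent_token_size=32, block_size=60)
        config.t5.d_model  # 512, the default T5 hidden size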
"""
    model_type = "transformer_vae"
    is_composition = True

    def __init__(
self,
t5_model_name_or_path=None,
n_latent_tokens=6, # set to -1 for full sequence
latent_token_size=32,
vae_encoder_model='',
vae_decoder_model='',
block_size=60,
decoder_start_token_id=0,
cache_dir=None,
tie_word_embeddings=True,
# T5 config
t5=dict(),
vocab_size=32128,
d_model=512,
d_kv=64,
d_ff=2048,
num_layers=6,
num_decoder_layers=None,
num_heads=8,
relative_attention_num_buckets=32,
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
feed_forward_proj="relu",
is_encoder_decoder=True,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
gradient_checkpointing=False,
# end
**kwargs,
):
assertIn(vae_encoder_model, VAE_ENCODER_MODELS.keys(), "Unexpected VAE encoder.")
assertIn(vae_decoder_model, VAE_DECODER_MODELS.keys(), "Unexpected VAE decoder.")
super().__init__(**kwargs)
self.set_seq_size = block_size
# VAE
self.vae_encoder_model = vae_encoder_model
self.vae_decoder_model = vae_decoder_model
self.latent_token_size = latent_token_size
        assert n_latent_tokens <= self.set_seq_size, 'Cannot use more latent tokens than input tokens.'
self.n_latent_tokens = n_latent_tokens
self.use_cache = use_cache
# T5
if t5_model_name_or_path:
self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
self.t5.decoder_start_token_id = decoder_start_token_id
elif t5:
# use for loading a config
self.t5 = T5Config(**t5)
else:
self.t5 = T5Config(
vocab_size=vocab_size,
d_model=d_model,
d_kv=d_kv,
d_ff=d_ff,
num_layers=num_layers,
num_decoder_layers=num_decoder_layers,
num_heads=num_heads,
relative_attention_num_buckets=relative_attention_num_buckets,
dropout_rate=dropout_rate,
layer_norm_epsilon=layer_norm_epsilon,
initializer_factor=initializer_factor,
feed_forward_proj=feed_forward_proj,
is_encoder_decoder=is_encoder_decoder,
use_cache=use_cache,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
gradient_checkpointing=gradient_checkpointing,
**kwargs
)
if self.t5.d_model < self.latent_token_size:
            raise Exception('Cannot use a latent token dimension larger than the T5 hidden dimension.')
# Add t5 config options
self.tie_word_embeddings = tie_word_embeddings
self.t5.tie_word_embeddings = self.tie_word_embeddings
self.t5.use_cache = self.use_cache
self.pad_token_id = pad_token_id
self.eos_token_id = eos_token_id
self.decoder_start_token_id = self.t5.decoder_start_token_id

    def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.
Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["model_type"] = self.__class__.model_type
output['t5'] = self.t5.to_dict()
return output
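

# A minimal usage sketch of T5VaeConfig. It assumes the default empty-string names are valid keys of
# VAE_ENCODER_MODELS / VAE_DECODER_MODELS and that no pretrained checkpoint needs to be downloaded.
if __name__ == "__main__":
    # Build a config from explicit T5 hyperparameters (the final `else` branch in `__init__`).
    config = T5VaeConfig(n_latent_tokens=6, latent_token_size=32, block_size=60)
    serialized = config.to_dict()
    assert serialized["model_type"] == "transformer_vae"

    # Round-trip: the nested T5 config dict is reloaded through the `elif t5:` branch of `__init__`.
    reloaded = T5VaeConfig(t5=serialized["t5"])
    assert reloaded.t5.d_model == config.t5.d_model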