File size: 3,203 Bytes
13df84c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
"""Moment model configuration"""
from typing import Optional

from transformers import PretrainedConfig
from transformers import logging
# Default settings for the T5 backbone; MomentConfig merges user overrides on
# top of these.
# NOTE(review): the values appear to mirror the "google/flan-t5-large" config.
# Four keys from that config are commented out in the original and are
# therefore NOT part of this dict: "_name_or_path" ("google/flan-t5-large"),
# "architectures" (["T5ForConditionalGeneration"]), "model_type" ("t5") and
# "transformers_version" ("4.33.3").
DEFAULT_T5_CONFIG = {
    "classifier_dropout": 0.0,
    "d_ff": 2816,
    "d_kv": 64,
    "d_model": 1024,
    "decoder_start_token_id": 0,
    "dense_act_fn": "gelu_new",
    "dropout_rate": 0.1,
    "eos_token_id": 1,
    "feed_forward_proj": "gated-gelu",
    "initializer_factor": 1.0,
    "is_encoder_decoder": False,
    "is_gated_act": True,
    "layer_norm_epsilon": 1e-06,
    "n_positions": 512,
    "num_decoder_layers": 24,
    "num_heads": 16,
    "num_layers": 24,
    "output_past": True,
    "pad_token_id": 0,
    "relative_attention_max_distance": 128,
    "relative_attention_num_buckets": 32,
    "tie_word_embeddings": False,
    "use_cache": False,
    "vocab_size": 32128,
}
class MomentConfig(PretrainedConfig):
    """Configuration object for a MOMENT model with a T5 backbone config.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name, with two exceptions:

    * ``t5_config`` is merged on top of ``DEFAULT_T5_CONFIG`` and stored as a
      new dict, so the module-level defaults are never shared with (or
      mutated through) an instance.
    * ``d_model`` falls back to ``t5_config["d_model"]`` when left as
      ``None``.

    How the stored values are interpreted is decided by the model code that
    consumes this config, which is not part of this class.
    """

    # Identifier string for this configuration class.
    model_type = "moment"

    def __init__(
        self,
        t5_config: Optional[dict] = None,
        d_model: Optional[int] = None,
        seq_len: int = 512,
        patch_len: int = 16,
        patch_stride_len: int = 16,
        dropout: float = 0.1,
        revin_num_features: int = 1,
        revin_eps: float = 1e-5,
        revin_affine: bool = True,
        add_positional_embedding: bool = True,
        value_embedding_bias: bool = False,
        orth_gain: float = 1.41,
        mask_ratio: float = 0.15,
        freeze_embedder: bool = True,
        freeze_encoder: bool = True,
        freeze_head: bool = False,
        enable_gradient_checkpointing: bool = True,
        randomly_initialize_backbone: bool = False,
        **kwargs,
    ):
        """Store the MOMENT hyper-parameters and initialise the base config.

        Args:
            t5_config: Overrides for the backbone settings. The given keys
                replace the matching entries of ``DEFAULT_T5_CONFIG``; keys
                that are not given keep their default value. ``None`` (the
                default) means "use ``DEFAULT_T5_CONFIG`` as is".
            d_model: Model width. When ``None`` it is taken from the merged
                ``t5_config["d_model"]``.
            seq_len, patch_len, patch_stride_len, dropout,
            revin_num_features, revin_eps, revin_affine,
            add_positional_embedding, value_embedding_bias, orth_gain,
            mask_ratio, freeze_embedder, freeze_encoder, freeze_head,
            enable_gradient_checkpointing, randomly_initialize_backbone:
                Stored verbatim under the same attribute name. Their meaning
                is defined by the consuming model code (not visible here).
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        self.t5_config = self._init_t5_config(t5_config)
        self.d_model = d_model
        self.seq_len = seq_len
        self.patch_len = patch_len
        self.patch_stride_len = patch_stride_len
        self.dropout = dropout
        self.revin_num_features = revin_num_features
        self.revin_eps = revin_eps
        self.revin_affine = revin_affine
        self.add_positional_embedding = add_positional_embedding
        self.value_embedding_bias = value_embedding_bias
        self.orth_gain = orth_gain
        self.mask_ratio = mask_ratio
        self.freeze_embedder = freeze_embedder
        self.freeze_encoder = freeze_encoder
        self.freeze_head = freeze_head
        self.enable_gradient_checkpointing = enable_gradient_checkpointing
        self.randomly_initialize_backbone = randomly_initialize_backbone

        # Must run after t5_config and d_model are set: it reads both.
        self._validation_config()

        super().__init__(**kwargs)

    def _init_t5_config(self, config: Optional[dict]) -> dict:
        """Return a new dict: ``DEFAULT_T5_CONFIG`` updated with ``config``.

        Args:
            config: Backbone overrides, or ``None`` for no overrides.

        Returns:
            A fresh (shallow) copy of the defaults with ``config`` applied.
            A copy is returned even when ``config`` is ``None`` so that
            editing ``self.t5_config`` can never alter the module-level
            defaults seen by other instances.
        """
        merged = dict(DEFAULT_T5_CONFIG)
        if config is not None:
            # Update the defaults with the caller-supplied config.
            merged.update(config)
        return merged

    def _validation_config(self):
        """Fill in configuration values that depend on other values.

        Currently the only step: when ``d_model`` was not given, use the
        backbone's ``d_model`` from ``self.t5_config``.
        """
        if self.d_model is None:
            self.d_model = self.t5_config["d_model"]
|