File size: 2,582 Bytes
3921342
 
038bacd
69e76c8
038bacd
 
3921342
038bacd
3921342
fdf47df
 
 
 
 
 
 
 
a44b70f
fdf47df
 
 
 
 
 
 
 
 
 
 
 
da44902
fdf47df
7e9a517
523f4fd
7e9a517
523f4fd
 
3921342
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from transformers import PretrainedConfig

class InfinityFormerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InfinityFormerModel`]. It is used to instantiate an
    InfinityFormer model according to the specified arguments, defining the model architecture.

    Every option below is read from keyword arguments and stored as an attribute of the same name. Any keyword not
    listed here is forwarded unchanged to [`PretrainedConfig`]. The exact meaning of each option is defined by the
    modeling code that consumes this config (not part of this file); the short descriptions follow the option names.

    Args:
        vocab_size (`int`, *optional*, defaults to 151669):
            Vocabulary size.
        hidden_size (`int`, *optional*, defaults to 768):
            Hidden dimension. Must be a multiple of `num_attention_heads`.
        num_hidden_layers (`int`, *optional*, defaults to 54):
            Number of hidden layers.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads. Must be positive and divide `hidden_size`.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Intermediate (feed-forward) dimension.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Hidden-state dropout probability.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            Attention-probability dropout probability.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            Maximum number of positions.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Weight-initialization range.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            Layer-norm epsilon.
        use_rotary_embeddings (`bool`, *optional*, defaults to `True`):
            Whether rotary position embeddings are enabled.
        rotary_embedding_base (`int`, *optional*, defaults to 10000):
            Rotary embedding base.
        use_multi_scale_memory (`bool`, *optional*, defaults to `True`):
            Whether multi-scale memory is enabled.
        num_memory_scales (`int`, *optional*, defaults to 3):
            Number of memory scales.
        memory_compression_ratio (`float`, *optional*, defaults to 0.5):
            Memory compression ratio.
        memory_compression_frequency (`int`, *optional*, defaults to 100):
            Memory compression frequency (units defined by the model).
        kernel_type (`str`, *optional*, defaults to `"elu"`):
            Kernel type; one of `"elu"`, `"relu"` or `"learnable"`.
        kernel_epsilon (`float`, *optional*, defaults to 0.1):
            Kernel epsilon.
        use_gating (`bool`, *optional*, defaults to `True`):
            Whether gating is enabled.
        gate_init_bias (`float`, *optional*, defaults to -2.0):
            Initial gate bias.
        use_memory_attention (`bool`, *optional*, defaults to `False`):
            Whether memory attention is enabled.
        use_gradient_checkpointing (`bool`, *optional*, defaults to `False`):
            Whether gradient checkpointing is enabled.
        use_return_dict (`bool`, *optional*):
            Alias for `return_dict`. When passed, it is written to `self.return_dict` and takes precedence over a
            `return_dict` keyword. When omitted, the `return_dict` keyword handled by [`PretrainedConfig`] is kept.

    Raises:
        ValueError: If `num_attention_heads` is not positive, if `hidden_size` is not a multiple of
            `num_attention_heads`, or if `kernel_type` is not one of the supported values.
    """
    model_type = "infinity_former"

    def __init__(self, **kwargs):
        # Model-specific options are popped so that only the remaining
        # keywords reach PretrainedConfig.__init__.
        self.vocab_size = kwargs.pop("vocab_size", 151669)
        self.hidden_size = kwargs.pop("hidden_size", 768)
        self.num_hidden_layers = kwargs.pop("num_hidden_layers", 54)
        self.num_attention_heads = kwargs.pop("num_attention_heads", 12)
        self.intermediate_size = kwargs.pop("intermediate_size", 3072)
        self.hidden_dropout_prob = kwargs.pop("hidden_dropout_prob", 0.1)
        self.attention_probs_dropout_prob = kwargs.pop("attention_probs_dropout_prob", 0.1)
        self.max_position_embeddings = kwargs.pop("max_position_embeddings", 8192)
        self.initializer_range = kwargs.pop("initializer_range", 0.02)
        self.layer_norm_eps = kwargs.pop("layer_norm_eps", 1e-5)
        self.use_rotary_embeddings = kwargs.pop("use_rotary_embeddings", True)
        self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
        self.use_multi_scale_memory = kwargs.pop("use_multi_scale_memory", True)
        self.num_memory_scales = kwargs.pop("num_memory_scales", 3)
        self.memory_compression_ratio = kwargs.pop("memory_compression_ratio", 0.5)
        self.memory_compression_frequency = kwargs.pop("memory_compression_frequency", 100)
        self.kernel_type = kwargs.pop("kernel_type", 'elu')
        self.kernel_epsilon = kwargs.pop("kernel_epsilon", 0.1)
        self.use_gating = kwargs.pop("use_gating", True)
        self.gate_init_bias = kwargs.pop("gate_init_bias", -2.0)
        self.use_memory_attention = kwargs.pop("use_memory_attention", False)
        self.use_gradient_checkpointing = kwargs.pop("use_gradient_checkpointing", False)

        # `use_return_dict` is not forwarded; it is mapped onto `return_dict`.
        # Only override when it was explicitly passed: an unconditional
        # assignment with a hard-coded default of True would silently discard
        # a `return_dict=False` handed to PretrainedConfig (e.g. when a saved
        # config is reloaded).
        use_return_dict = kwargs.pop("use_return_dict", None)
        super().__init__(**kwargs)
        if use_return_dict is not None:
            self.return_dict = use_return_dict

        # Guard before the modulo below so a zero head count produces a clear
        # configuration error instead of a bare ZeroDivisionError.
        if self.num_attention_heads <= 0:
            raise ValueError(
                f"`num_attention_heads` must be a positive integer, got {self.num_attention_heads}"
            )
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"`hidden_size` ({self.hidden_size}) must be a multiple of `num_attention_heads` "
                f"({self.num_attention_heads})"
            )
        if self.kernel_type not in ['elu', 'relu', 'learnable']:
            raise ValueError(f"`kernel_type` must be one of 'elu', 'relu', or 'learnable', got {self.kernel_type}")