from transformers import T5Config

POSITION_ENCODING_REL_T5_BIAS = "t5_relative_bias"
POSITION_ENCODING_REL_TRANSFORMER_XL = "transformer_xl_relative_encoding"
POSITION_ENCODING_ROTARY = "rotary"
POSITION_ENCODING_ROTARY_RERUN = "rotary_rerun"
POSITION_ENCODING_ROTARY_NEW = "new_rotary"
POSITION_ENCODING_ABS_LEARNED = "abs_learned"
POSITION_ENCODING_ABS_SINUSOID = "abs_sinusoid"
POSITION_ENCODING_ALiBi = "alibi"
POSITION_ENCODING_ALiBi_LEARNED = "alibi_learned"
POSITION_ENCODING_NONE = "none"
POSITION_ENCODING_NONE_WINDOW = "none_window"


class CustomT5Config(T5Config):
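    """T5Config variant for a decoder-only T5 model.

    Stores the selected position encoding scheme in ``position_encoding_type``
    and validates it against the identifiers defined above.
    """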
    model_type = "custom_decoder_only_t5"

    def __init__(
        self,
        position_encoding_type=POSITION_ENCODING_REL_T5_BIAS,
        **kwargs,
    ):
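        # Reject unknown encoding type names early, at config-construction time.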
        if position_encoding_type not in [
            POSITION_ENCODING_ALiBi,
            POSITION_ENCODING_ALiBi_LEARNED,
            POSITION_ENCODING_ABS_LEARNED,
            POSITION_ENCODING_ABS_SINUSOID,
            POSITION_ENCODING_REL_T5_BIAS,
            POSITION_ENCODING_REL_TRANSFORMER_XL,
            POSITION_ENCODING_ROTARY,
            POSITION_ENCODING_ROTARY_NEW,
            POSITION_ENCODING_NONE,
            POSITION_ENCODING_NONE_WINDOW,
        ]:
            raise ValueError(
                f"Invalid position_encoding_type: {position_encoding_type}"
            )
        self.position_encoding_type = position_encoding_type
        super().__init__(**kwargs)
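
# A minimal usage sketch (an assumption, not part of the original file): it only
# illustrates constructing the config with one of the identifiers above. The extra
# keyword arguments are standard T5Config parameters, shown purely for illustration.
if __name__ == "__main__":
    config = CustomT5Config(
        position_encoding_type=POSITION_ENCODING_ROTARY,
        d_model=512,
        num_layers=6,
    )
    print(config.position_encoding_type)  # -> "rotary"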