import warnings

from transformers import PretrainedConfig

class DearthConfig(PretrainedConfig):
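    """Configuration class for the Dearth model.

    Stores the model hyperparameters: embedding/attention sizes, sliding-window
    attention settings (``sliding_window_size``, ``front_window_size``),
    rotary/ALiBi positional-encoding switches, and the optional ``mimic_*``
    attention settings for the layer selected by ``mimic_attn_layer``.
    Serialization and special-token handling are inherited from ``PretrainedConfig``.
    """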
    model_type = "dearth"
    def __init__(
        self, 
        max_token_len: int = 8192,
        vocab_size: int = None,  # e.g. the GPT-2 vocab size of 50257, padded up to the nearest multiple of 64 for efficiency
        n_layer: int = None,
        n_head: int = None,
        n_kv_head: int = None,
        dim: int = None,
        dim_qk_head: int = None,
        hidden_dim: int = None,
        multiple_of: int = None,
        dropout_rate: float = 0.0,
        layer_init_factor: float = None,
        residual_factor: float = None,  # should be > 1.0
        sliding_window_size: int = 4096,
        front_window_size: int = 256,
        use_rotary: bool = True,
        rope_theta: float = 10000.0,
        use_alibi: bool = False,

        mimic_attn_layer: int = None,  # 1-based, counting from the bottom; the first layer is 1, not 0
        mimic_n_head: int = None,
        mimic_n_kv_head: int = None,
        mimic_attn_dropout: float = None,
        mimic_dim_qk_head: int = None,
        mimic_use_rotary: bool = True,
        mimic_use_alibi: bool = False,

        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.max_token_len = max_token_len
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.dim = dim
        self.dim_qk_head = dim_qk_head
        self.hidden_dim = hidden_dim
        if hidden_dim is None and dim is not None:
            # Default the feed-forward hidden dimension to 4x the model dimension.
            self.hidden_dim = dim * 4
        self.multiple_of = multiple_of
        self.dropout_rate = dropout_rate
        self.layer_init_factor = layer_init_factor
        self.residual_factor = residual_factor
        self.sliding_window_size = sliding_window_size
        self.front_window_size = front_window_size
        self.use_rotary = use_rotary
        self.rope_theta = rope_theta
        self.use_alibi = use_alibi

        self.mimic_attn_layer = mimic_attn_layer
        self.mimic_n_head = mimic_n_head
        self.mimic_n_kv_head = mimic_n_kv_head
        self.mimic_attn_dropout = mimic_attn_dropout
        self.mimic_dim_qk_head = mimic_dim_qk_head
        self.mimic_use_rotary = mimic_use_rotary
        self.mimic_use_alibi = mimic_use_alibi

        if "attn_window_size" in kwargs:
            warnings.warn(
                "attn_window_size is deprecated; use sliding_window_size instead.",
                DeprecationWarning,
            )
            self.sliding_window_size = kwargs.pop("attn_window_size")

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def __str__(self) -> str:
        return f"""
        max_token_len = {self.max_token_len}
        vocab_size = {self.vocab_size}
        n_layer = {self.n_layer}
        n_head = {self.n_head}
        n_kv_head = {self.n_kv_head}
        dim = {self.dim}
        dim_qk_head = {self.dim_qk_head}
        hidden_dim = {self.hidden_dim}
        multiple_of = {self.multiple_of}
        dropout_rate = {self.dropout_rate}
        layer_init_factor = {self.layer_init_factor}
        residual_factor = {self.residual_factor}
        sliding_window_size = {self.sliding_window_size}
        front_window_size = {self.front_window_size}
        use_rotary = {self.use_rotary}
        rope_theta = {self.rope_theta}
        use_alibi = {self.use_alibi}

        mimic_attn_layer = {self.mimic_attn_layer}
        mimic_n_head = {self.mimic_n_head}
        mimic_n_kv_head = {self.mimic_n_kv_head}
        mimic_attn_dropout = {self.mimic_attn_dropout}
        mimic_dim_qk_head = {self.mimic_dim_qk_head}
        mimic_use_rotary = {self.mimic_use_rotary}
        mimic_use_alibi = {self.mimic_use_alibi}
        """