File size: 2,575 Bytes
b3abc18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from typing import Optional,Tuple,List


class DCPythiaConfig(PretrainedConfig):
    """Configuration class for DCPythia models.

    Adapted from gpt-fast's ``ModelArgs``:
    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L21

    Every argument is stored verbatim as an attribute of the same name,
    except for the following values, which are resolved in ``__init__``:

    * ``n_local_heads == -1`` is replaced by ``n_head``.
    * ``intermediate_size is None`` is replaced by ``4 * dim``.
    * ``head_dim`` is always recomputed as ``dim // n_head``; the value passed
      in is ignored.

    ``pad_token_id``, ``bos_token_id``, ``eos_token_id``,
    ``tie_word_embeddings`` and any extra ``**kwargs`` are forwarded to
    ``PretrainedConfig.__init__``.

    Field meanings below follow gpt-fast / GPT-NeoX naming. They are not
    enforced by this class; confirm them against the modeling code.

    Args:
        block_size: Presumably the maximum sequence length.
        vocab_size: Vocabulary size.
        n_layer: Number of transformer layers.
        n_head: Number of attention heads.
        dim: Model (hidden) width.
        intermediate_size: MLP hidden width; ``None`` means ``4 * dim``.
        n_local_heads: Number of key/value heads (gpt-fast convention);
            ``-1`` means ``n_head``.
        head_dim: Ignored; overwritten with ``dim // n_head``.
        rope_base: Rotary position embedding base.
        norm_eps: Epsilon used by normalization layers.
        use_gradient_checkpointing: Presumably enables activation
            checkpointing in the model.
        is_training: Presumably selects a training-time code path.
        q_chunk_size: Presumably the query chunk length used by the
            attention implementation.
        use_dcmha: Presumably toggles dynamically composable multi-head
            attention (DCMHA).
        use_qk_norm: Presumably toggles normalization of queries/keys.
        window_size: Presumably the local-attention window length.
        window_type: Presumably selects which layers use windowed attention.
        query_wise: Presumably toggles query-wise dynamic composition.
        pad_token_id: Forwarded to ``PretrainedConfig``.
        use_parallel_residual: Presumably GPT-NeoX-style parallel
            attention/MLP residual.
        use_linear_bias: Presumably adds biases to linear layers.
        rotary_pct: Presumably the fraction of each head's dims that receive
            rotary embeddings.
        bos_token_id: Forwarded to ``PretrainedConfig``.
        eos_token_id: Forwarded to ``PretrainedConfig``.
        tie_word_embeddings: Forwarded to ``PretrainedConfig``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "dcpythia"

    def __init__(
        self,
        block_size: int = 2048,
        vocab_size: int = 32000,
        n_layer: int = 32,
        n_head: int = 32,
        dim: int = 2560,
        intermediate_size: Optional[int] = None,
        n_local_heads: int = -1,
        head_dim: int = 64,
        rope_base: float = 10000,
        norm_eps: float = 1e-5,
        use_gradient_checkpointing: bool = False,
        is_training: bool = False,
        q_chunk_size: int = 128,
        use_dcmha: bool = True,
        use_qk_norm: bool = False,
        window_size: Optional[int] = 256,
        window_type: Optional[str] = None,
        query_wise: bool = False,
        pad_token_id: Optional[int] = None,
        use_parallel_residual: bool = True,
        use_linear_bias: bool = True,
        rotary_pct: float = 0.25,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        **kwargs,
    ):
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.dim = dim
        self.intermediate_size = intermediate_size
        self.n_local_heads = n_local_heads
        self.head_dim = head_dim
        self.rope_base = rope_base
        self.norm_eps = norm_eps
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.is_training = is_training
        self.q_chunk_size = q_chunk_size
        self.use_dcmha = use_dcmha
        self.use_qk_norm = use_qk_norm
        self.window_size = window_size
        self.window_type = window_type
        self.query_wise = query_wise
        self.use_parallel_residual = use_parallel_residual
        self.use_linear_bias = use_linear_bias
        self.rotary_pct = rotary_pct

        # Resolve sentinel / derived values.
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            self.intermediate_size = 4 * self.dim
        # NOTE: this unconditionally overrides the ``head_dim`` argument.
        self.head_dim = self.dim // self.n_head

        # Token ids, embedding tying and any extra kwargs are handled by the
        # Hugging Face base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )