File size: 5,032 Bytes
71b8d08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# by syncdoth: https://github.com/syncdoth/RetNet/blob/main/retnet/configuration_retnet.py

from dataclasses import dataclass
import json

from transformers.configuration_utils import PretrainedConfig


def load_config_from_json(config_file):
    """Build a ``RetNetConfig`` from a JSON file.

    Args:
        config_file: Path (or any other value accepted by ``open``) of a JSON
            file holding the configuration values.

    Returns:
        Whatever ``RetNetConfig.from_dict`` returns for the parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so configs containing non-ASCII text load the same everywhere.
    with open(config_file, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    # The file handle is no longer needed once the JSON has been parsed.
    return RetNetConfig.from_dict(config_dict)


@dataclass
class RetNetConfig(PretrainedConfig):
    """Hyperparameter container for a RetNet decoder.

    The annotated class attributes below repeat the default values used by
    ``__init__``. ``__init__`` stores every model-specific argument on the
    instance and forwards the generic options (``is_decoder``,
    ``pad_token_id``, ``eos_token_id``, ``use_cache``, ``tie_word_embeddings``
    and any extra ``**kwargs``) to ``PretrainedConfig.__init__``.
    """

    model_type = "retnet"

    initializer_range: float = 0.02
    activation_fn: str = "gelu"
    dropout: float = 0.0  # dropout probability
    activation_dropout: float = 0.0  # dropout probability applied after the FFN activation
    drop_path_rate: float = 0.0
    decoder_embed_dim: int = 768  # embedding dimension of the decoder
    decoder_value_embed_dim: int = 1280  # value embedding dimension of the decoder
    decoder_ffn_embed_dim: int = 1280  # embedding dimension used inside the FFN
    decoder_layers: int = 12  # number of decoder layers
    decoder_retention_heads: int = 3  # number of retention heads in the decoder
    decoder_normalize_before: bool = True  # layernorm goes before each decoder block
    layernorm_embedding: bool = False  # add a layernorm to the embedding
    no_scale_embedding: bool = True  # when True, embeddings are not scaled
    recurrent_chunk_size: int = 512
    use_lm_decay: bool = False
    use_glu: bool = True  # GLU instead of a plain FFN
    z_loss_coeff: float = 0.0  # z-loss coefficient; TODO: 1e-4
    deepnorm: bool = False
    subln: bool = True
    use_ffn_rms_norm: bool = False
    layernorm_eps: float = 1e-6
    tie_word_embeddings: bool = False

    def __init__(
        self,
        vocab_size: int = 50257,
        initializer_range: float = 0.02,
        is_decoder: bool = True,
        pad_token_id: int = 0,
        eos_token_id: int = 0,
        output_retentions: bool = False,
        use_cache: bool = True,
        forward_impl: str = "parallel",
        activation_fn: str = "gelu",
        dropout: float = 0.0,  # dropout probability
        activation_dropout: float = 0.0,  # dropout probability applied after the FFN activation
        drop_path_rate: float = 0.0,
        decoder_embed_dim: int = 768,  # embedding dimension of the decoder
        decoder_value_embed_dim: int = 1280,  # value embedding dimension of the decoder
        decoder_ffn_embed_dim: int = 1280,  # embedding dimension used inside the FFN
        decoder_layers: int = 12,  # number of decoder layers
        decoder_retention_heads: int = 3,  # number of retention heads in the decoder
        decoder_normalize_before: bool = True,  # layernorm goes before each decoder block
        layernorm_embedding: bool = False,  # add a layernorm to the embedding
        no_scale_embedding: bool = True,  # when True, embeddings are not scaled
        recurrent_chunk_size: int = 512,
        use_glu: bool = True,  # GLU instead of a plain FFN
        z_loss_coeff: float = 0.0,  # z-loss coefficient; TODO: 1e-4
        use_lm_decay: bool = False,
        deepnorm: bool = False,
        subln: bool = True,
        use_ffn_rms_norm: bool = False,  # RMSNorm instead of LayerNorm in the FFN
        layernorm_eps: float = 1e-6,
        tie_word_embeddings: bool = False,
        **kwargs,
    ):
        """Store the RetNet hyperparameters and initialise the base config.

        Every argument except ``is_decoder``, ``pad_token_id``,
        ``eos_token_id``, ``use_cache``, ``tie_word_embeddings`` and
        ``**kwargs`` is saved as an instance attribute of the same name; those
        remaining ones are handed to ``PretrainedConfig.__init__``.

        ``deepnorm`` and ``subln`` are reconciled after storing: a truthy
        ``deepnorm`` switches off both ``subln`` and
        ``decoder_normalize_before``; otherwise a truthy ``subln`` switches
        ``decoder_normalize_before`` on and sets ``deepnorm`` to ``False``.
        """
        # Attribute name -> value, listed in the order the attributes are
        # created on the instance.
        stored_values = {
            "vocab_size": vocab_size,
            "initializer_range": initializer_range,
            "output_retentions": output_retentions,
            "use_lm_decay": use_lm_decay,
            "use_glu": use_glu,
            "z_loss_coeff": z_loss_coeff,
            # size related
            "decoder_embed_dim": decoder_embed_dim,
            "decoder_value_embed_dim": decoder_value_embed_dim,
            "decoder_retention_heads": decoder_retention_heads,
            "decoder_ffn_embed_dim": decoder_ffn_embed_dim,
            "decoder_layers": decoder_layers,
            # normalization related
            "decoder_normalize_before": decoder_normalize_before,
            "activation_fn": activation_fn,
            "dropout": dropout,
            "drop_path_rate": drop_path_rate,
            "activation_dropout": activation_dropout,
            "no_scale_embedding": no_scale_embedding,
            "layernorm_embedding": layernorm_embedding,
            "deepnorm": deepnorm,
            "subln": subln,
            "use_ffn_rms_norm": use_ffn_rms_norm,
            "layernorm_eps": layernorm_eps,
            # Blockwise
            "recurrent_chunk_size": recurrent_chunk_size,
            "forward_impl": forward_impl,
        }
        for attr_name, attr_value in stored_values.items():
            setattr(self, attr_name, attr_value)

        # deepnorm wins over subln: once the first branch has run, subln is
        # False, so the two branches can never both apply.
        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        elif self.subln:
            self.decoder_normalize_before = True
            self.deepnorm = False

        super().__init__(
            is_decoder=is_decoder,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def override(self, args):
        """Overwrite stored values with same-named attributes of ``args``.

        For every key currently in this instance's ``__dict__``, the value is
        replaced when ``args`` has an attribute of that name whose value is
        not ``None``. Names missing on ``args`` (or set to ``None``) are left
        untouched, and no new keys are added.
        """
        own_attrs = self.__dict__
        # Only existing keys are reassigned, so the dict never changes size
        # while it is being iterated.
        for hp in own_attrs:
            replacement = getattr(args, hp, None)
            if replacement is None:
                continue
            own_attrs[hp] = replacement