import math

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class MoSMambaConfig(PretrainedConfig):
    """Configuration class for a MoSMamba model.

    Stores the hyperparameters used to build the model, following the
    `transformers` `PretrainedConfig` conventions. In addition to the usual
    Mamba settings (hidden size, SSM state size, convolution kernel, time-step
    parameters, ...), `num_selectivities` sets how many selectivity mechanisms
    are available and `num_selectivities_per_tok` how many of them are used
    for each token.
    """

    model_type = "MoSMamba"

    def __init__(
        self,
        vocab_size=50280,
        hidden_size=768,
        state_size=16,
        num_selectivities=6,
        num_selectivities_per_tok=2,
        num_hidden_layers=32,
        layer_norm_epsilon=1e-5,
        pad_token_id=0,
        bos_token_id=0,
        eos_token_id=0,
        expand=2,
        conv_kernel=4,
        use_bias=False,
        use_conv_bias=True,
        hidden_act="silu",
        initializer_range=0.1,
        residual_in_fp32=True,
        time_step_rank="auto",
        time_step_scale=1.0,
        time_step_min=0.001,
        time_step_max=0.1,
        time_step_init_scheme="random",
        time_step_floor=1e-4,
        rescale_prenorm_residual=False,
        use_cache=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.num_hidden_layers = num_hidden_layers
        self.layer_norm_epsilon = layer_norm_epsilon
        self.conv_kernel = conv_kernel
        self.expand = expand
        self.intermediate_size = int(expand * self.hidden_size)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.use_bias = use_bias
        self.use_conv_bias = use_conv_bias
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
        self.time_step_scale = time_step_scale
        self.time_step_min = time_step_min
        self.time_step_max = time_step_max
        self.time_step_init_scheme = time_step_init_scheme
        self.time_step_floor = time_step_floor
        self.rescale_prenorm_residual = rescale_prenorm_residual
        self.residual_in_fp32 = residual_in_fp32
        self.use_cache = use_cache

        self.num_selectivities = num_selectivities
        self.num_selectivities_per_tok = num_selectivities_per_tok

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
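
# --- Illustrative usage sketch (not part of the original module) ---
# Builds a small config and checks the derived fields; assumes only the class
# defined above and an installed `transformers` package.
if __name__ == "__main__":
    config = MoSMambaConfig(hidden_size=512, num_hidden_layers=4)
    # time_step_rank defaults to ceil(hidden_size / 16) when left as "auto".
    print(config.time_step_rank)  # 32
    # intermediate_size is derived as int(expand * hidden_size).
    print(config.intermediate_size)  # 1024
    # The selectivity-related parameters keep their defaults here.
    print(config.num_selectivities, config.num_selectivities_per_tok)  # 6 2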