from typing import Optional

import mamba_ssm
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel

# Instantiate the upstream mamba_ssm config once so its field defaults can be reused below.
mamba_config_defaults = mamba_ssm.models.config_mamba.MambaConfig()

class MambaConfig(PretrainedConfig):
    """Hugging Face configuration that mirrors the fields of mamba_ssm's MambaConfig."""

    model_type = "mamba"

    def __init__(
        self,
        d_model: int = mamba_config_defaults.d_model,
        fused_add_norm: bool = mamba_config_defaults.fused_add_norm,
        n_layer: int = mamba_config_defaults.n_layer,
        pad_vocab_size_multiple: int = mamba_config_defaults.pad_vocab_size_multiple,
        residual_in_fp32: bool = mamba_config_defaults.residual_in_fp32,
        rms_norm: bool = mamba_config_defaults.rms_norm,
        ssm_cfg: Optional[dict] = None,
        vocab_size: int = mamba_config_defaults.vocab_size,
        **kwargs,
    ):
        self.d_model = d_model
        self.fused_add_norm = fused_add_norm
        self.n_layer = n_layer
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.residual_in_fp32 = residual_in_fp32
        self.rms_norm = rms_norm
        # Use a None sentinel and copy the upstream default so instances never share a mutable dict.
        self.ssm_cfg = ssm_cfg if ssm_cfg is not None else dict(mamba_config_defaults.ssm_cfg)
        self.vocab_size = vocab_size

        super().__init__(**kwargs)
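

# A minimal usage sketch relying only on the serialization API that PretrainedConfig
# already provides: build a config with non-default values (the 1024/24 values and the
# "mamba-config-example" directory are arbitrary), write it out as config.json, and read
# it back. AutoConfig.register("mamba", MambaConfig) would additionally hook the class
# into the Auto* machinery, assuming the installed transformers version does not already
# reserve the "mamba" model_type.
if __name__ == "__main__":
    config = MambaConfig(d_model=1024, n_layer=24)
    config.save_pretrained("mamba-config-example")  # writes mamba-config-example/config.json

    reloaded = MambaConfig.from_pretrained("mamba-config-example")
    assert reloaded.d_model == 1024 and reloaded.n_layer == 24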