import mamba_ssm
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)

# Borrow the default hyperparameters from mamba_ssm's own config dataclass
# so this wrapper stays in sync with the upstream defaults.
mamba_config_defaults = mamba_ssm.models.config_mamba.MambaConfig()


class MambaConfig(PretrainedConfig):
    """Hugging Face PretrainedConfig wrapper around mamba_ssm's MambaConfig."""

    model_type = "mamba"

    def __init__(
        self,
        d_model: int = mamba_config_defaults.d_model,
        fused_add_norm: bool = mamba_config_defaults.fused_add_norm,
        n_layer: int = mamba_config_defaults.n_layer,
        pad_vocab_size_multiple: int = mamba_config_defaults.pad_vocab_size_multiple,
        residual_in_fp32: bool = mamba_config_defaults.residual_in_fp32,
        rms_norm: bool = mamba_config_defaults.rms_norm,
        ssm_cfg: dict = mamba_config_defaults.ssm_cfg,
        vocab_size: int = mamba_config_defaults.vocab_size,
        **kwargs,
    ):
        self.d_model = d_model
        self.fused_add_norm = fused_add_norm
        self.n_layer = n_layer
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.residual_in_fp32 = residual_in_fp32
        self.rms_norm = rms_norm
        self.ssm_cfg = ssm_cfg
        self.vocab_size = vocab_size
        # Let PretrainedConfig handle the remaining standard kwargs
        # (e.g. bos/eos token ids, tie_word_embeddings, ...).
        super().__init__(**kwargs)
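

# Usage sketch (not part of the original file; assumes a matching PreTrainedModel
# wrapper is defined elsewhere): the class round-trips like any other
# PretrainedConfig and can be registered with the Auto classes.
if __name__ == "__main__":
    config = MambaConfig(n_layer=24)
    config.save_pretrained("./mamba-config")              # writes config.json
    restored = MambaConfig.from_pretrained("./mamba-config")

    # Optionally register the type so AutoConfig can resolve "mamba". Note that
    # recent transformers releases ship their own built-in "mamba" config, in
    # which case this call raises unless a different model_type is used.
    AutoConfig.register("mamba", MambaConfig)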