import math

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class MoSMambaConfig(PretrainedConfig):
    """
    Configuration class to store the configuration of a MoSMamba model (Mamba with a
    mixture of selectivities). It is used to instantiate the model according to the
    specified arguments, defining the model architecture.

    Most arguments mirror the standard Mamba configuration: the SSM ``state_size``,
    the depthwise ``conv_kernel`` width, the ``expand`` factor for the inner projection,
    and the ``time_step_*`` parameters governing the dt discretization.
    ``num_selectivities`` and ``num_selectivities_per_tok`` set the total number of
    selectivity experts and how many are routed per token, by analogy with
    mixture-of-experts configurations.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control
    the model outputs. Read the documentation from [`PretrainedConfig`] for more
    information.
    """

    model_type = "MoSMamba"

    def __init__(
        self,
        vocab_size=50280,
        hidden_size=768,
        state_size=16,
        num_selectivities=6,
        num_selectivities_per_tok=2,
        num_hidden_layers=32,
        layer_norm_epsilon=1e-5,
        pad_token_id=0,
        bos_token_id=0,
        eos_token_id=0,
        expand=2,
        conv_kernel=4,
        use_bias=False,
        use_conv_bias=True,
        hidden_act="silu",
        initializer_range=0.1,
        residual_in_fp32=True,
        time_step_rank="auto",
        time_step_scale=1.0,
        time_step_min=0.001,
        time_step_max=0.1,
        time_step_init_scheme="random",
        time_step_floor=1e-4,
        rescale_prenorm_residual=False,
        use_cache=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.num_hidden_layers = num_hidden_layers
        self.layer_norm_epsilon = layer_norm_epsilon
        self.conv_kernel = conv_kernel
        self.expand = expand
        # Inner (projected) dimension of each block: expand * hidden_size.
        self.intermediate_size = int(expand * self.hidden_size)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.use_bias = use_bias
        self.use_conv_bias = use_conv_bias
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        # Rank of the dt (time-step) projection; "auto" resolves to ceil(hidden_size / 16).
        self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
        self.time_step_scale = time_step_scale
        self.time_step_min = time_step_min
        self.time_step_max = time_step_max
        self.time_step_init_scheme = time_step_init_scheme
        self.time_step_floor = time_step_floor
        self.rescale_prenorm_residual = rescale_prenorm_residual
        self.residual_in_fp32 = residual_in_fp32
        self.use_cache = use_cache
        self.num_selectivities = num_selectivities
        self.num_selectivities_per_tok = num_selectivities_per_tok

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
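

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It shows
# how the derived attributes fall out of the arguments above:
# intermediate_size = expand * hidden_size, and time_step_rank resolves to
# ceil(hidden_size / 16) when left as "auto". The round-trip uses the standard
# PretrainedConfig to_dict()/from_dict() API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = MoSMambaConfig(hidden_size=1024, num_selectivities=8, num_selectivities_per_tok=2)

    print(config.intermediate_size)  # 2048 == 2 * 1024
    print(config.time_step_rank)     # 64 == ceil(1024 / 16)

    # Serialize and rebuild, as with any PretrainedConfig subclass.
    restored = MoSMambaConfig.from_dict(config.to_dict())
    assert restored.num_selectivities == 8
    assert restored.num_selectivities_per_tok == 2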