| import math | |
| from transformers.configuration_utils import PretrainedConfig | |
| from transformers.utils import logging | |
| logger = logging.get_logger(__name__) | |
| class MoSMambaConfig(PretrainedConfig): | |
| model_type = "MoSMamba" | |
| def __init__( | |
| self, | |
| vocab_size=50280, | |
| hidden_size=768, | |
| state_size=16, | |
| num_selectivities=6, | |
| num_selectivities_per_tok=2, | |
| num_hidden_layers=32, | |
| layer_norm_epsilon=1e-5, | |
| pad_token_id=0, | |
| bos_token_id=0, | |
| eos_token_id=0, | |
| expand=2, | |
| conv_kernel=4, | |
| use_bias=False, | |
| use_conv_bias=True, | |
| hidden_act="silu", | |
| initializer_range=0.1, | |
| residual_in_fp32=True, | |
| time_step_rank="auto", | |
| time_step_scale=1.0, | |
| time_step_min=0.001, | |
| time_step_max=0.1, | |
| time_step_init_scheme="random", | |
| time_step_floor=1e-4, | |
| rescale_prenorm_residual=False, | |
| use_cache=True, | |
| **kwargs, | |
| ): | |
| self.vocab_size = vocab_size | |
| self.hidden_size = hidden_size | |
| self.state_size = state_size | |
| self.num_hidden_layers = num_hidden_layers | |
| self.layer_norm_epsilon = layer_norm_epsilon | |
| self.conv_kernel = conv_kernel | |
| self.expand = expand | |
| self.intermediate_size = int(expand * self.hidden_size) | |
| self.bos_token_id = bos_token_id | |
| self.eos_token_id = eos_token_id | |
| self.pad_token_id = pad_token_id | |
| self.use_bias = use_bias | |
| self.use_conv_bias = use_conv_bias | |
| self.hidden_act = hidden_act | |
| self.initializer_range = initializer_range | |
| self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank | |
| self.time_step_scale = time_step_scale | |
| self.time_step_min = time_step_min | |
| self.time_step_max = time_step_max | |
| self.time_step_init_scheme = time_step_init_scheme | |
| self.time_step_floor = time_step_floor | |
| self.rescale_prenorm_residual = rescale_prenorm_residual | |
| self.residual_in_fp32 = residual_in_fp32 | |
| self.use_cache = use_cache | |
| self.num_selectivities = num_selectivities | |
| self.num_selectivities_per_tok = num_selectivities_per_tok | |
| super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) |