"""Caduceus config for Hugging Face. """ from typing import Optional, Union from transformers import PretrainedConfig class CaduceusConfig(PretrainedConfig): """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance.""" model_type = "caduceus" def __init__( self, # From original MambaConfig d_model: int = 2560, d_intermediate: int = 0, use_mamba2: bool = False, n_layer: int = 64, vocab_size: int = 50277, ssm_cfg: Optional[dict] = None, rms_norm: bool = True, residual_in_fp32: bool = True, fused_add_norm: bool = True, pad_vocab_size_multiple: int = 8, # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm norm_epsilon: float = 1e-5, # Used in init_weights initializer_cfg: Optional[dict] = None, # Caduceus-specific params bidirectional: bool = True, bidirectional_strategy: Union[str, None] = "add", bidirectional_weight_tie: bool = True, rcps: bool = False, complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead pos_embeddings: Optional[str] = None, row_first: Optional[bool] = True, **kwargs, ): super().__init__(**kwargs) self.d_model = d_model self.d_intermediate = d_intermediate self.use_mamba2 = use_mamba2 self.n_layer = n_layer self.vocab_size = vocab_size self.ssm_cfg = ssm_cfg self.rms_norm = rms_norm self.residual_in_fp32 = residual_in_fp32 self.fused_add_norm = fused_add_norm self.pad_vocab_size_multiple = pad_vocab_size_multiple self.norm_epsilon = norm_epsilon self.initializer_cfg = initializer_cfg self.bidirectional = bidirectional self.bidirectional_strategy = bidirectional_strategy self.bidirectional_weight_tie = bidirectional_weight_tie self.rcps = rcps self.complement_map = complement_map self.pos_embeddings = pos_embeddings self.row_first = row_first class AxialCaduceusConfig(PretrainedConfig): """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance.""" model_type = "axial_caduceus" def __init__( self, # From original MambaConfig d_model: int = 2560, d_intermediate: int = 0, use_mamba2: bool = False, n_layer: int = 64, vocab_size: int = 50277, ssm_cfg: Optional[dict] = None, rms_norm: bool = True, residual_in_fp32: bool = True, fused_add_norm: bool = True, pad_vocab_size_multiple: int = 8, # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm norm_epsilon: float = 1e-5, # Used in init_weights initializer_cfg: Optional[dict] = None, # Caduceus-specific params bidirectional: bool = True, bidirectional_strategy: Union[str, None] = "add", bidirectional_weight_tie: bool = True, rcps: bool = False, complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead pos_embeddings: Optional[str] = None, row_first: Optional[bool] = True, **kwargs, ): super().__init__(**kwargs) self.d_model = d_model self.d_intermediate = d_intermediate self.use_mamba2 = use_mamba2 self.n_layer = n_layer self.vocab_size = vocab_size self.ssm_cfg = ssm_cfg self.rms_norm = rms_norm self.residual_in_fp32 = residual_in_fp32 self.fused_add_norm = fused_add_norm self.pad_vocab_size_multiple = pad_vocab_size_multiple self.norm_epsilon = norm_epsilon self.initializer_cfg = initializer_cfg self.bidirectional = bidirectional self.bidirectional_strategy = bidirectional_strategy self.bidirectional_weight_tie = bidirectional_weight_tie self.rcps = rcps self.complement_map = complement_map self.pos_embeddings = pos_embeddings self.row_first = row_first class 

class MixedCaduceusConfig(PretrainedConfig):
    """Config that extends the original CaduceusConfig with params relevant to alternating between attention and Caduceus blocks."""

    model_type = "mixed_caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        d_intermediate: int = 0,
        use_mamba2: bool = False,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Union[str, None] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        # Attention-specific params
        attn_d_model: int = 128,
        attn_n_heads: int = 16,
        attn_attn_dropout: float = 0.1,
        attn_block_dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.d_intermediate = d_intermediate
        self.use_mamba2 = use_mamba2
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map
        self.attn_d_model = attn_d_model
        self.attn_n_heads = attn_n_heads
        self.attn_attn_dropout = attn_attn_dropout
        self.attn_block_dropout = attn_block_dropout
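

# A minimal usage sketch, not part of the original module: build a small CaduceusConfig
# and register the three model types with AutoConfig. The registration step and all
# parameter values below (d_model=128, n_layer=4, vocab_size=12, the complement_map)
# are illustrative assumptions, not something defined by this file.
if __name__ == "__main__":
    from transformers import AutoConfig

    # Hypothetical registration so AutoConfig.from_pretrained can resolve these
    # model_type strings; the surrounding repo may wire this up elsewhere.
    AutoConfig.register("caduceus", CaduceusConfig)
    AutoConfig.register("axial_caduceus", AxialCaduceusConfig)
    AutoConfig.register("mixed_caduceus", MixedCaduceusConfig)

    # Small config with RC equivariance enabled; complement_map pairs each token id
    # with its reverse-complement counterpart (toy ids chosen for illustration).
    config = CaduceusConfig(
        d_model=128,
        n_layer=4,
        vocab_size=12,
        rcps=True,
        complement_map={0: 3, 1: 2, 2: 1, 3: 0},
    )
    print(config.to_json_string())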