"""Caduceus config for Hugging Face."""

from typing import Optional, Union

from transformers import PretrainedConfig


class CaduceusConfig(PretrainedConfig):
    """Hugging Face config for Caduceus models.

    Holds the hyper-parameters of the original MambaConfig together with the
    extra switches Caduceus introduces for bi-directionality and
    reverse-complement (RC) equivariance.

    Every named constructor argument is stored, unchanged and unvalidated, as
    an instance attribute of the same name. Any other keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.

    Args:
        d_model, n_layer, vocab_size, ssm_cfg, rms_norm, residual_in_fp32,
        fused_add_norm, pad_vocab_size_multiple: Fields carried over from the
            original MambaConfig under the same names.
        norm_epsilon: Epsilon used by the layer norm. Not part of the original
            MambaConfig; mirrors the default argument of ``create_block`` in
            the mamba_ssm repo.
        initializer_cfg: Options consumed by the weight-initialisation routine
            (``init_weights``).
        bidirectional, bidirectional_strategy, bidirectional_weight_tie, rcps:
            Caduceus-specific switches. Presumably they toggle bi-directional
            processing, pick how the two directions are combined (default
            ``"add"``) and enable RC parameter sharing; confirm against the
            modeling code.
        complement_map: Mapping consumed by RCPSEmbedding / RCPSLMHead.
            Presumably maps each token id to its complement token id; verify
            against the caller.
        **kwargs: Passed straight through to ``PretrainedConfig.__init__``.
    """

    model_type = "caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Union[str, None] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        **kwargs,
    ):
        # Initialise the PretrainedConfig base with the leftover keyword
        # arguments before any model-specific attribute is assigned.
        super().__init__(**kwargs)

        # Backbone fields inherited from MambaConfig.
        backbone_fields = (
            ("d_model", d_model),
            ("n_layer", n_layer),
            ("vocab_size", vocab_size),
            ("ssm_cfg", ssm_cfg),
            ("rms_norm", rms_norm),
            ("residual_in_fp32", residual_in_fp32),
            ("fused_add_norm", fused_add_norm),
            ("pad_vocab_size_multiple", pad_vocab_size_multiple),
        )
        # Normalisation / initialisation settings.
        setup_fields = (
            ("norm_epsilon", norm_epsilon),
            ("initializer_cfg", initializer_cfg),
        )
        # Caduceus extensions: bi-directionality and RC equivariance.
        caduceus_fields = (
            ("bidirectional", bidirectional),
            ("bidirectional_strategy", bidirectional_strategy),
            ("bidirectional_weight_tie", bidirectional_weight_tie),
            ("rcps", rcps),
            ("complement_map", complement_map),
        )

        # setattr routes through the same __setattr__ as plain attribute
        # assignment, and the groups are applied in signature order.
        for field_name, field_value in (*backbone_fields, *setup_fields, *caduceus_fields):
            setattr(self, field_name, field_value)