caduceus_base / configuration_caduceus.py
yairschiff's picture
Upload Caduceus
5ddd9c9 verified
"""Caduceus config for Hugging Face.
"""
from typing import Optional, Union
from transformers import PretrainedConfig
class CaduceusConfig(PretrainedConfig):
    """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance.

    Every constructor argument is stored unchanged as an instance attribute of
    the same name; no validation or normalization is performed here.

    Args:
        d_model: From the original MambaConfig.
        n_layer: From the original MambaConfig.
        vocab_size: From the original MambaConfig.
        ssm_cfg: From the original MambaConfig; optional dict, stored as given
            (``None`` is kept as ``None``).
        rms_norm: From the original MambaConfig.
        residual_in_fp32: From the original MambaConfig.
        fused_add_norm: From the original MambaConfig.
        pad_vocab_size_multiple: From the original MambaConfig.
        norm_epsilon: Not in the original MambaConfig, but a default arg of
            ``create_block`` in the mamba_ssm repo; used in layer norm.
        initializer_cfg: Optional dict used in ``init_weights``.
        bidirectional: Caduceus-specific flag.
        bidirectional_strategy: Caduceus-specific; a string or ``None``.
            NOTE(review): the set of accepted strategy names is defined by the
            model code, not here -- confirm against the modeling module.
        bidirectional_weight_tie: Caduceus-specific flag.
        rcps: Caduceus-specific flag.
        complement_map: Optional dict used for RCPSEmbedding / RCPSLMHead.
        **kwargs: Forwarded untouched to ``PretrainedConfig.__init__``.
    """

    model_type = "caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Union[str, None] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        **kwargs,
    ):
        # The base class consumes the generic options first; the attributes
        # below are then set on top of whatever it initialized.
        super().__init__(**kwargs)

        # --- Settings inherited from the original MambaConfig ---
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple

        # --- Normalization / weight-initialization extras ---
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg

        # --- Caduceus-specific settings ---
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map