""" | |
Modeling module for Mamba models.
""" | |
import importlib | |
def check_mamba_ssm_installed():
    """
    Verify that the optional ``mamba_ssm`` package can be located.

    Returns:
        None. The function only has an effect when the package is missing.

    Raises:
        ImportError: if no import spec can be found for ``mamba_ssm``.
    """
    # Import the submodule explicitly: a bare ``import importlib`` does not
    # guarantee that ``importlib.util`` is bound as an attribute, so relying
    # on ``importlib.util.find_spec`` can fail with AttributeError.
    from importlib.util import find_spec

    if find_spec("mamba_ssm") is None:
        raise ImportError(
            "MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
        )
def fix_mamba_attn_for_loss():
    """
    Monkey-patch ``mamba_ssm`` to use this package's ``MambaLMHeadModel``.

    The ``MambaLMHeadModel`` attribute of
    ``mamba_ssm.models.mixer_seq_simple`` is rebound to the class of the same
    name defined in the sibling ``modeling_mamba`` module.

    Returns:
        The class now bound to ``mixer_seq_simple.MambaLMHeadModel``.

    Raises:
        ImportError: propagated from ``check_mamba_ssm_installed`` when
            ``mamba_ssm`` cannot be found.
    """
    # Fail early with the install hint before touching the optional package.
    check_mamba_ssm_installed()

    # Deferred imports: these run only after the availability check above.
    from mamba_ssm.models import mixer_seq_simple as upstream_module

    from .modeling_mamba import MambaLMHeadModel as PatchedMambaLMHeadModel

    upstream_module.MambaLMHeadModel = PatchedMambaLMHeadModel
    return upstream_module.MambaLMHeadModel  # pylint: disable=invalid-name