File size: 656 Bytes
40a6362
 
 
 
d69ba2b
 
 
 
 
 
 
 
 
 
40a6362
 
d69ba2b
 
40a6362
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""
Modeling module for Mamba models
"""

import importlib


def check_mamba_ssm_installed():
    """Ensure the optional ``mamba_ssm`` package can be located.

    Returns:
        None. The function only validates the environment.

    Raises:
        ImportError: If no import spec for ``mamba_ssm`` can be found.
    """
    # Import the submodule explicitly: a bare ``import importlib`` does not
    # guarantee that ``importlib.util`` has been loaded, so attribute access
    # on the package can fail with AttributeError depending on what else was
    # imported earlier in the process.
    from importlib.util import find_spec

    if find_spec("mamba_ssm") is None:
        raise ImportError(
            "MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
        )


def fix_mamba_attn_for_loss():
    """Swap the upstream ``MambaLMHeadModel`` for the local implementation.

    After verifying that ``mamba_ssm`` is installed, the ``MambaLMHeadModel``
    attribute of ``mamba_ssm.models.mixer_seq_simple`` is rebound to the class
    of the same name defined in the sibling ``modeling_mamba`` module.

    Returns:
        The class now bound to ``mixer_seq_simple.MambaLMHeadModel``.

    Raises:
        ImportError: Propagated from ``check_mamba_ssm_installed`` when
            ``mamba_ssm`` is not available.
    """
    check_mamba_ssm_installed()

    # Deferred imports: both modules are only loaded once the dependency
    # check above has passed.
    from mamba_ssm.models import mixer_seq_simple as upstream_module

    from .modeling_mamba import MambaLMHeadModel as PatchedMambaLMHeadModel

    # Monkey-patch the upstream module attribute in place.
    # NOTE(review): the function name suggests the replacement class changes
    # loss handling — confirm against modeling_mamba.
    upstream_module.MambaLMHeadModel = PatchedMambaLMHeadModel
    return PatchedMambaLMHeadModel