Commit
•
3f92511
1
Parent(s):
f59f64a
Update modeling_mos_mamba.py
Browse files- modeling_mos_mamba.py +18 -5
modeling_mos_mamba.py
CHANGED
@@ -26,22 +26,35 @@ from torch.nn import CrossEntropyLoss
|
|
26 |
from transformers.activations import ACT2FN
|
27 |
from transformers.modeling_utils import PreTrainedModel
|
28 |
from transformers.utils import ModelOutput
|
29 |
-
from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
|
30 |
from .configuration_mos_mamba import MoSMambaConfig
|
31 |
|
32 |
import torch.nn.functional as F
|
33 |
|
34 |
|
35 |
-
if is_mamba_ssm_available():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
|
37 |
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
38 |
-
|
39 |
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
40 |
|
41 |
-
|
42 |
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
43 |
-
|
44 |
causal_conv1d_update, causal_conv1d_fn = None, None
|
|
|
45 |
|
46 |
is_fast_path_available = all(
|
47 |
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
|
|
26 |
from transformers.activations import ACT2FN
|
27 |
from transformers.modeling_utils import PreTrainedModel
|
28 |
from transformers.utils import ModelOutput
|
29 |
+
# from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
|
30 |
from .configuration_mos_mamba import MoSMambaConfig
|
31 |
|
32 |
import torch.nn.functional as F
|
33 |
|
34 |
|
35 |
+
# if is_mamba_ssm_available():
|
36 |
+
# from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
|
37 |
+
# from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
38 |
+
# else:
|
39 |
+
# selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
40 |
+
|
41 |
+
# if is_causal_conv1d_available():
|
42 |
+
# from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
43 |
+
# else:
|
44 |
+
# causal_conv1d_update, causal_conv1d_fn = None, None
|
45 |
+
|
46 |
+
|
47 |
+
try:
|
48 |
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
|
49 |
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
50 |
+
except:
|
51 |
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
52 |
|
53 |
+
try:
|
54 |
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
55 |
+
except:
|
56 |
causal_conv1d_update, causal_conv1d_fn = None, None
|
57 |
+
|
58 |
|
59 |
is_fast_path_available = all(
|
60 |
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|